1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_sharding_metadata.h"
17 
18 #include "absl/memory/memory.h"
19 #include "tensorflow/compiler/xla/service/hlo_computation.h"
20 #include "tensorflow/compiler/xla/shape_tree.h"
21 #include "tensorflow/compiler/xla/shape_util.h"
22 
23 namespace xla {
24 
25 namespace {
26 
27 // AssignmentKind and kUnassignedDevice are used during tuple domain sharding
28 // propagation in order to distinguish among three cases:
29 // kUnassigned: no assignment has occurred
30 // kAssigned: at least an assignment has occurred
31 // kConflict: no assignment has occurred because of conflicting propagations,
32 // which occurs when multiple users of an instruction have different
33 // shardings.
34 enum class AssignmentKind { kUnassigned, kAssigned, kConflict };
35 
36 // kUnassignedDevice can only be assigned to tuple leaf shardings to indicate
37 // absence of sharding information for that particular sub-sharding during
38 // sharding propagation. It is used to be able to express tuple shardings with
39 // partial information. At the end of the propagation the sharding of
40 // tuple-shaped instructions using kUnassignedDevice's is cleared.
41 // TODO(b/112883246): Centralized enum of reserved devices.
42 constexpr int64 kUnassignedDevice = -2;
43 
44 struct PassThrough {
PassThroughxla::__anon96d769990111::PassThrough45   PassThrough(HloInstruction* user, HloInstruction* operand)
46       : user(user), operand(operand) {}
47 
48   HloInstruction* user = nullptr;
49   HloInstruction* operand = nullptr;
50 };
51 
SetSingleSharding(HloInstruction * instruction,const HloSharding & sharding)52 void SetSingleSharding(HloInstruction* instruction,
53                        const HloSharding& sharding) {
54   VLOG(4) << "  " << instruction->name() << " to " << sharding;
55   instruction->set_single_sharding(sharding);
56 }
57 
ShardingMatches(const HloSharding & sharding1,const HloSharding & sharding2)58 bool ShardingMatches(const HloSharding& sharding1,
59                      const HloSharding& sharding2) {
60   auto single_sharding1 = sharding1.ExtractSingleSharding();
61   if (single_sharding1) {
62     auto single_sharding2 = sharding2.ExtractSingleSharding();
63     if (single_sharding2) {
64       return *single_sharding1 == single_sharding2;
65     }
66   }
67   // Anything which is not unique across all elements, gets a full sharding
68   // compare.
69   return sharding1 == sharding2;
70 }
71 
72 // When we create domains, they are never "empty", where with empty we mean
73 // that a kDomain instruction has as operand another kDomain instruction of the
74 // same kind.
75 // But when the HLO optimizations are run, empty domains can be created.
76 // For example:
77 //
78 //  Domain(device=None, device=0) ->
79 //    Tuple(device=0) ->
80 //      GTE(device=0) ->
81 //        Domain(device=0, device=None)
82 //
83 // In that case the tuple simplifier could create something like:
84 //
85 //  Domain(device=None, device=0) -> Domain(device=0, device=None)
86 //
87 // Which is a so called empty domain.
88 // In the case above, crossing an empty domain which was transiting through
89 // device 0, requires the normalization phase to fixup the empty domain by
90 // adding back a Tuple+GTE pair with the proper device.
91 // One particular case where this can create problems is the result of the
92 // entry computation, where the GTE assignments are used by TF to tell the
93 // XLA where the results should be sent.
LocatePassThroughDomainLinks(const DomainMetadata::Domain & domain)94 std::vector<PassThrough> LocatePassThroughDomainLinks(
95     const DomainMetadata::Domain& domain) {
96   std::vector<PassThrough> pass_through;
97   for (HloInstruction* instruction : domain.enter_domains) {
98     CHECK(instruction->opcode() == HloOpcode::kDomain)
99         << "Instruction is not a kDomain: " << instruction->ToString();
100     for (HloInstruction* user : instruction->users()) {
101       if (user->opcode() == HloOpcode::kDomain &&
102           domain.exit_domains.contains(user)) {
103         pass_through.emplace_back(user, instruction);
104         VLOG(2) << "Found passthrough domain link:";
105         VLOG(2) << "  " << user->ToString();
106         VLOG(2) << "  " << instruction->ToString();
107       }
108     }
109     if (instruction == instruction->parent()->root_instruction()) {
110       pass_through.emplace_back(nullptr, instruction);
111       VLOG(2) << "Found passthrough domain link:";
112       VLOG(2) << "  <root>";
113       VLOG(2) << "  " << instruction->ToString();
114     }
115   }
116   return pass_through;
117 }
118 
FixupPassThroughDomainLinks(const DomainMetadata::Domain & domain,const HloSharding & sharding)119 Status FixupPassThroughDomainLinks(const DomainMetadata::Domain& domain,
120                                    const HloSharding& sharding) {
121   for (auto& pass_through : LocatePassThroughDomainLinks(domain)) {
122     HloInstruction* tuple = pass_through.operand->parent()->AddInstruction(
123         HloInstruction::CreateTuple({pass_through.operand}));
124     HloInstruction* gte = pass_through.operand->parent()->AddInstruction(
125         HloInstruction::CreateGetTupleElement(pass_through.operand->shape(),
126                                               tuple, 0));
127     gte->set_sharding(sharding);
128     if (pass_through.user != nullptr) {
129       TF_RETURN_IF_ERROR(
130           pass_through.operand->ReplaceUseWith(pass_through.user, gte));
131     } else {
132       pass_through.operand->parent()->set_root_instruction(gte);
133     }
134   }
135   return Status::OK();
136 }
137 
138 // For tuple shardings if every element have the same sharsing then we want to
139 // treat them as single element sharsings to insert less domain separation as a
140 // domain can prevent some optimizations and we want to minimize that from
141 // happening.
CloneShardingForDomain(std::shared_ptr<const HloSharding> sharding)142 std::shared_ptr<const HloSharding> CloneShardingForDomain(
143     std::shared_ptr<const HloSharding> sharding) {
144   auto single_sharding = sharding->ExtractSingleSharding();
145   if (!single_sharding) {
146     return sharding;
147   }
148   return std::make_shared<const HloSharding>(*single_sharding);
149 }
150 
ApplyDomainSingleSharding(const DomainMetadata::Domain & domain,const HloSharding & sharding)151 Status ApplyDomainSingleSharding(const DomainMetadata::Domain& domain,
152                                  const HloSharding& sharding) {
153   VLOG(4) << "Applying " << sharding << " sharding";
154   for (HloInstruction* instruction : domain.instructions) {
155     // We only change instructions without sharding, since otherwise we might
156     // mess up with eventual HLO passes which has knowledge of it.
157     if (!instruction->has_sharding()) {
158       SetSingleSharding(instruction, sharding);
159     } else {
160       VLOG(4) << "  " << instruction->name() << " already has sharding "
161               << instruction->sharding();
162     }
163   }
164   return Status::OK();
165 }
166 
167 // Return the ShapeTree<HloSharding> of the user argument. The user argument
168 // is assumed to be a user of the instruction argument.
169 // If user is a tuple instruction, return the tuple subsharding corresponding to
170 // the operand matching the instruction argument, because that is the
171 // subsharding corresponding to instruction.
GetShardingTreeFromUser(const HloInstruction & instruction,const HloInstruction & user)172 StatusOr<ShapeTree<HloSharding>> GetShardingTreeFromUser(
173     const HloInstruction& instruction, const HloInstruction& user) {
174   if (user.opcode() == HloOpcode::kTuple) {
175     return user.sharding()
176         .GetSubSharding(user.shape(), {user.operand_index(&instruction)})
177         .AsShapeTree(instruction.shape());
178   }
179   return user.sharding().AsShapeTree(user.shape());
180 }
181 
182 // Assign rhs to lhs. If rhs is unassigned (assigned to kUnassignedDevice)
183 // then no assignment is made. Therefore kUnassignedDevice is never propagated.
184 // kConflict is returned if lhs is already assigned and rhs is assigned to a
185 // different device.
AssignLeafSharding(HloSharding * lhs,const HloSharding & rhs)186 StatusOr<AssignmentKind> AssignLeafSharding(HloSharding* lhs,
187                                             const HloSharding& rhs) {
188   TF_RET_CHECK(!lhs->IsTuple() && !rhs.IsTuple());
189   if (rhs.UsesDevice(kUnassignedDevice)) {
190     return AssignmentKind::kUnassigned;
191   }
192   if (lhs->UsesDevice(kUnassignedDevice)) {
193     *lhs = rhs;
194     return AssignmentKind::kAssigned;
195   }
196   return lhs->UniqueDevice() != rhs.UniqueDevice()
197              ? AssignmentKind::kConflict
198              : AssignmentKind::kUnassigned;
199 }
200 
201 // Assigns the whole rhs tree to lhs_tree, starting at lhs_it.
202 // In case of conflicting assignment AssignmentKind::kConflict is returned. In
203 // this case lhs_tree is partially assigned, up to the conflicting leaf. It is
204 // up to the caller to discard the partial assignment in case of conflict.
AssignTreeSharding(ShapeTree<HloSharding> * lhs_tree,ShapeTree<HloSharding>::iterator lhs_it,const ShapeTree<HloSharding> & rhs_tree)205 StatusOr<AssignmentKind> AssignTreeSharding(
206     ShapeTree<HloSharding>* lhs_tree, ShapeTree<HloSharding>::iterator lhs_it,
207     const ShapeTree<HloSharding>& rhs_tree) {
208   AssignmentKind assigned = AssignmentKind::kUnassigned;
209   auto rhs_it = rhs_tree.begin();
210   for (; lhs_it != lhs_tree->end() && rhs_it != rhs_tree.end();
211        ++lhs_it, ++rhs_it) {
212     // TODO(b/112885211): Add ShapeTree::IsLeaf(const ShapeTreeIterator &it)
213     if (rhs_tree.IsLeaf(rhs_it->first)) {
214       TF_RET_CHECK(lhs_tree->IsLeaf(lhs_it->first));
215       TF_ASSIGN_OR_RETURN(AssignmentKind sub_assigned,
216                           AssignLeafSharding(&lhs_it->second, rhs_it->second));
217       if (sub_assigned == AssignmentKind::kConflict) {
218         // In case of conflict we return conflict to the caller. At this point
219         // partial assignments to lhs_tree may have been made already. It is up
220         // to the caller to discard the partial assignment in case of conflict.
221         return AssignmentKind::kConflict;
222       } else if (sub_assigned == AssignmentKind::kAssigned) {
223         assigned = sub_assigned;
224       }
225     }
226   }
227   TF_RET_CHECK(rhs_it == rhs_tree.end());
228   return assigned;
229 }
230 
ApplyShardingFromUsers(HloInstruction * instruction,const DomainMetadata::Domain & domain,const HloSharding & domain_sharding)231 StatusOr<bool> ApplyShardingFromUsers(HloInstruction* instruction,
232                                       const DomainMetadata::Domain& domain,
233                                       const HloSharding& domain_sharding) {
234   if (instruction->users().empty()) {
235     // No sharding from users, use domain_sharding, after checking
236     // compatibility.
237     TF_RET_CHECK(instruction->shape().IsTuple() &&
238                  ShapeUtil::GetLeafCount(instruction->shape()) ==
239                      domain_sharding.tuple_elements().size());
240     instruction->set_sharding(domain_sharding);
241     return true;
242   }
243   AssignmentKind assigned = AssignmentKind::kUnassigned;
244   // The sharding_tree leaves are initialized to kUnassignedDevice. Only Tuple
245   // subshardings can result in a final sharding assignment containing
246   // kUnassignedDevice leaves, in case some tuple indexes are not used, or are
247   // used by users that don't have a sharding.
248   // Non-tuple shardings are either assigned to a real sharding, or are not
249   // assigned at all. As such they will never get assigned to kUnassignedDevice.
250   // In any case, kUnassignedDevice is never propagated, from the implementation
251   // of AssignLeafSharding.
252   ShapeTree<HloSharding> sharding_tree(
253       instruction->shape(), HloSharding::AssignDevice(kUnassignedDevice));
254   for (HloInstruction* user : instruction->users()) {
255     if (user->opcode() == HloOpcode::kDomain &&
256         domain.exit_domains.contains(user)) {
257       // If a user is a domain and it is registered in the domain exits, then
258       // the instruction sharding is taken directly from the domain, and no
259       // further users need to be visited.
260       instruction->set_sharding(domain_sharding);
261       return true;
262     }
263     if (!user->has_sharding()) {
264       continue;
265     }
266     AssignmentKind sub_assigned = AssignmentKind::kUnassigned;
267     TF_ASSIGN_OR_RETURN(ShapeTree<HloSharding> user_sharding_tree,
268                         GetShardingTreeFromUser(*instruction, *user));
269     if (instruction->shape().IsTuple()) {
270       // For tuple-shaped instructions collect individual tuple subshardings
271       // from the uses, and then combine them into the tuple sharding.
272       // If the user is a GTE its sharding concerns only the subtree of
273       // sharding_tree at index user->tuple_index, otherwise the whole
274       // sharding_tree is affected.
275       ShapeTree<HloSharding>::iterator sharding_tree_begin =
276           user->opcode() == HloOpcode::kGetTupleElement
277               ? sharding_tree.find({user->tuple_index()})
278               : sharding_tree.begin();
279       TF_ASSIGN_OR_RETURN(
280           sub_assigned, AssignTreeSharding(&sharding_tree, sharding_tree_begin,
281                                            user_sharding_tree));
282     } else {
283       // Non-tuple shape: assign common users sharding.
284       TF_RET_CHECK(user_sharding_tree.leaf_count() == 1)
285           << "Expected non-tuple user sharding";
286       TF_ASSIGN_OR_RETURN(
287           sub_assigned,
288           AssignTreeSharding(&sharding_tree, sharding_tree.begin(),
289                              user_sharding_tree));
290     }
291 
292     if (sub_assigned == AssignmentKind::kConflict) {
293       // In case of conflict we don't assign any sharding.
294       return false;
295     } else if (sub_assigned == AssignmentKind::kAssigned) {
296       assigned = sub_assigned;
297     }
298   }
299 
300   if (assigned == AssignmentKind::kAssigned) {
301     if (instruction->shape().IsTuple()) {
302       instruction->set_sharding(HloSharding::Tuple(sharding_tree));
303     } else {
304       TF_RET_CHECK(sharding_tree.leaf_count() == 1);
305       instruction->set_sharding(sharding_tree.leaf_begin()->second);
306     }
307     return true;
308   }
309   return false;
310 }
311 
312 // Tries to propagate the sharding information into the instructions that are
313 // part of the domain, in a reverse post order manner (users propoagate to
314 // instruction).
ApplyDomainShardingPass(const DomainMetadata::Domain & domain,const HloSharding & domain_sharding)315 StatusOr<int64> ApplyDomainShardingPass(const DomainMetadata::Domain& domain,
316                                         const HloSharding& domain_sharding) {
317   int64 assigned = 0;
318   // domain.instructions are ordered in a post-order manner. As we do
319   // user->operand propagation we process instructions in reverse order. In so
320   // doing we are guaranteed to process all users before their operands.
321   for (auto it = domain.instructions.rbegin(); it != domain.instructions.rend();
322        ++it) {
323     HloInstruction* instruction = *it;
324     if (instruction->has_sharding()) {
325       continue;
326     }
327     // Take the sharding from the users.
328     TF_ASSIGN_OR_RETURN(
329         bool instruction_assigned,
330         ApplyShardingFromUsers(instruction, domain, domain_sharding));
331     if (instruction_assigned) {
332       ++assigned;
333       VLOG(4) << "  " << instruction->name() << " to sharding "
334               << instruction->sharding();
335     }
336   }
337   return assigned;
338 }
339 
ApplyDomainSharding(const DomainMetadata::Domain & domain,const HloSharding & sharding)340 Status ApplyDomainSharding(const DomainMetadata::Domain& domain,
341                            const HloSharding& sharding) {
342   // None of the external normalizers handled the domain sharding, try to see
343   // whether this is a single sharding first.
344   auto single_sharding = sharding.ExtractSingleSharding();
345   if (single_sharding) {
346     // Shortcut the simple case. We have a unique sharding, so we call
347     // the ApplyDomainSingleSharding() API which will apply array or tuple
348     // shaped sharding to the domain instructions.
349     return ApplyDomainSingleSharding(domain, *single_sharding);
350   }
351   VLOG(1) << "Assigning non-trivial sharding " << sharding;
352   TF_RETURN_IF_ERROR(ApplyDomainShardingPass(domain, sharding).status());
353 
354   int64 unassigned = 0;
355   for (HloInstruction* instruction : domain.instructions) {
356     if (!instruction->has_sharding()) {
357       LOG(WARNING) << "Unassigned instruction: " << instruction->ToString();
358       ++unassigned;
359     } else {
360       // Un-set sharding of tuples whose sub-sgardings are assigned to
361       // kUnassignedDevice. Indeed in case of doubt it is better to leave the
362       // entire tuple unassigned, and let the device placer decide for it.
363       if (instruction->sharding().UsesDevice(kUnassignedDevice)) {
364         TF_RET_CHECK(instruction->shape().IsTuple())
365             << "Only tuples can have kUnassignedDevice sub shardings";
366         instruction->clear_sharding();
367       }
368     }
369   }
370   // Should we error out if unassigned > 0?
371   return Status::OK();
372 }
373 
ExtractOriginalCommonSharding(absl::Span<HloInstruction * const> instructions)374 StatusOr<std::shared_ptr<const HloSharding>> ExtractOriginalCommonSharding(
375     absl::Span<HloInstruction* const> instructions) {
376   // If we are here, all the instructions being passed had the same sharding
377   // (or no sharding), by the means of the ShardingMatches() API.
378   // As such, no kDomain was inserted, and here we are asked to extract the
379   // original common sharding.
380   // All the instructions passed to this API are part of the same computation.
381   std::shared_ptr<const HloSharding> sharding;
382   for (HloInstruction* instruction : instructions) {
383     if (instruction->has_sharding()) {
384       if (sharding == nullptr) {
385         sharding = instruction->sharding_ptr();
386       } else {
387         TF_RET_CHECK(ShardingMatches(*sharding, instruction->sharding()))
388             << "Sharding " << *sharding << " does not match the one in "
389             << instruction->ToString();
390       }
391     }
392   }
393   if (sharding == nullptr) {
394     return std::shared_ptr<const HloSharding>();
395   }
396   VLOG(4) << "Extracted sharding is " << *sharding;
397   return CloneShardingForDomain(sharding);
398 }
399 
400 }  // namespace
401 
Clone() const402 std::unique_ptr<DomainMetadata> ShardingMetadata::Clone() const {
403   std::unique_ptr<HloSharding> sharding;
404   if (sharding_ != nullptr) {
405     sharding = absl::make_unique<HloSharding>(*sharding_);
406   }
407   return absl::make_unique<ShardingMetadata>(std::move(sharding));
408 }
409 
Matches(const DomainMetadata & other) const410 bool ShardingMetadata::Matches(const DomainMetadata& other) const {
411   const ShardingMetadata* other_ptr =
412       dynamic_cast<const ShardingMetadata*>(&other);
413   if (other_ptr == nullptr) {
414     // If other is not a ShardingMetadata, then it is clearly a no match.
415     return false;
416   }
417   if (sharding_ == nullptr) {
418     return other_ptr->sharding_ == nullptr;
419   }
420   return other_ptr->sharding_ != nullptr
421              ? ShardingMatches(*sharding_, *other_ptr->sharding_)
422              : false;
423 }
424 
Hash() const425 size_t ShardingMetadata::Hash() const {
426   if (sharding_ != nullptr) {
427     return sharding_->Hash();
428   }
429   return static_cast<size_t>(0x297814aaad196e6dULL);
430 }
431 
ToString() const432 string ShardingMetadata::ToString() const {
433   return sharding_ != nullptr ? sharding_->ToString() : "{}";
434 }
435 
436 /*static*/ StatusOr<const ShardingMetadata*>
ToShardingMetadata(const DomainMetadata * metadata)437 ShardingMetadata::ToShardingMetadata(const DomainMetadata* metadata) {
438   if (metadata->Kind() != ShardingMetadata::KindName()) {
439     return Status(
440         tensorflow::error::INVALID_ARGUMENT,
441         "ShardingMetadata normalizer called with incorrect domain metadata");
442   }
443   return static_cast<const ShardingMetadata*>(metadata);
444 }
445 
NormalizeShardingDomain(const DomainMetadata::Domain & domain,const DomainMetadata * metadata)446 Status ShardingMetadata::NormalizeShardingDomain(
447     const DomainMetadata::Domain& domain, const DomainMetadata* metadata) {
448   if (metadata != nullptr) {
449     TF_ASSIGN_OR_RETURN(const auto& sharding_metadata,
450                         ToShardingMetadata(metadata));
451     const HloSharding* sharding = sharding_metadata->sharding();
452     if (sharding != nullptr) {
453       VLOG(4) << "Normalizing sharding to " << sharding->ToString() << ":";
454       TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
455       TF_RETURN_IF_ERROR(FixupPassThroughDomainLinks(domain, *sharding));
456     }
457   } else {
458     TF_ASSIGN_OR_RETURN(std::shared_ptr<const HloSharding> sharding,
459                         ExtractOriginalCommonSharding(domain.instructions));
460     if (sharding != nullptr) {
461       VLOG(4) << "Normalizing sharding-less domain to " << sharding->ToString();
462       TF_RETURN_IF_ERROR(ApplyDomainSharding(domain, *sharding));
463     } else {
464       VLOG(1) << "Unable to find common sharding";
465     }
466   }
467   return Status::OK();
468 }
469 
470 // Creates a kDomain instruction to be placed between instruction and operand.
471 // The kDomain instruction will be created only if the sharding differ between
472 // the instruction and the operand.
operator ()(HloInstruction * instruction,HloInstruction * root,HloInstruction * operand)473 HloInstruction* ShardingDomainCreator::operator()(HloInstruction* instruction,
474                                                   HloInstruction* root,
475                                                   HloInstruction* operand) {
476   auto instruction_sharding = instruction->sharding_ptr();
477   auto root_sharding = root->sharding_ptr();
478   // No need for domain if they both have no sharding.
479   if (instruction_sharding == nullptr && root_sharding == nullptr) {
480     return nullptr;
481   }
482   // No need for domain if they match.
483   if (instruction_sharding != nullptr && root_sharding != nullptr &&
484       ShardingMatches(*instruction_sharding, *root_sharding)) {
485     return nullptr;
486   }
487 
488   if (instruction_sharding != nullptr) {
489     instruction_sharding = CloneShardingForDomain(instruction_sharding);
490   }
491   if (root_sharding != nullptr) {
492     root_sharding = CloneShardingForDomain(root_sharding);
493   }
494 
495   auto it = domain_cse_map_.find({operand, instruction_sharding});
496   if (it != domain_cse_map_.end()) {
497     return it->second;
498   }
499 
500   VLOG(3) << "Creating domain:";
501   VLOG(3) << "  Instruction: " << instruction->name();
502   VLOG(3) << "  Operand: " << operand->name();
503   VLOG(3) << "    User side sharding: "
504           << (instruction_sharding != nullptr ? instruction_sharding->ToString()
505                                               : "None");
506   VLOG(3) << "    Operand side sharding: "
507           << (root_sharding != nullptr ? root_sharding->ToString() : "None");
508 
509   HloInstruction* domain =
510       operand->parent()->AddInstruction(HloInstruction::CreateDomain(
511           operand->shape(), operand,
512           absl::make_unique<ShardingMetadata>(root_sharding),
513           absl::make_unique<ShardingMetadata>(instruction_sharding)));
514   domain_cse_map_.emplace(DomainCseMapKey{operand, instruction_sharding},
515                           domain);
516   return domain;
517 }
518 
operator ==(const ShardingDomainCreator::DomainCseMapKey & other) const519 bool ShardingDomainCreator::DomainCseMapKey::operator==(
520     const ShardingDomainCreator::DomainCseMapKey& other) const {
521   if (instruction != other.instruction) {
522     return false;
523   }
524   if (sharding == nullptr && other.sharding == nullptr) {
525     return true;
526   }
527   if (sharding == nullptr || other.sharding == nullptr) {
528     return false;
529   }
530   return *sharding == *other.sharding;
531 }
532 
operator ()(const ShardingDomainCreator::DomainCseMapKey & key) const533 size_t ShardingDomainCreator::DomainCseMapHasher::operator()(
534     const ShardingDomainCreator::DomainCseMapKey& key) const {
535   return tensorflow::Hash64Combine(
536       std::hash<const HloInstruction*>{}(key.instruction),
537       key.sharding ? key.sharding->Hash()
538                    : static_cast<size_t>(0x297814aaad196e6dULL));
539 }
540 
541 }  // namespace xla
542