1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_domain_map.h"
17 
18 #include <algorithm>
19 
20 #include "absl/container/flat_hash_map.h"
21 #include "absl/container/flat_hash_set.h"
22 #include "absl/memory/memory.h"
23 #include "tensorflow/compiler/xla/map_util.h"
24 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
25 #include "tensorflow/compiler/xla/types.h"
26 
27 namespace xla {
28 
Create(HloComputation * computation,string domain_kind)29 /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
30     HloComputation* computation, string domain_kind) {
31   auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
32   TF_RETURN_IF_ERROR(domain_map->Populate(computation));
33   return std::move(domain_map);
34 }
35 
Create(HloModule * module,string domain_kind)36 /* static */ StatusOr<std::unique_ptr<HloDomainMap>> HloDomainMap::Create(
37     HloModule* module, string domain_kind) {
38   auto domain_map = absl::WrapUnique(new HloDomainMap(std::move(domain_kind)));
39   for (HloComputation* computation : module->computations()) {
40     TF_RETURN_IF_ERROR(domain_map->Populate(computation));
41   }
42   return std::move(domain_map);
43 }
44 
InSameDomain(const HloInstruction * instruction1,const HloInstruction * instruction2) const45 bool HloDomainMap::InSameDomain(const HloInstruction* instruction1,
46                                 const HloInstruction* instruction2) const {
47   int64 domain_id1 = GetDomainId(instruction1);
48   int64 domain_id2 = GetDomainId(instruction2);
49   return domain_id1 >= 0 && domain_id1 == domain_id2;
50 }
51 
GetDomainId(const HloInstruction * instruction) const52 int64 HloDomainMap::GetDomainId(const HloInstruction* instruction) const {
53   return FindOrDefault(instruction_to_domain_, instruction, -1);
54 }
55 
GetDomainMetadataId(const HloInstruction * instruction) const56 int64 HloDomainMap::GetDomainMetadataId(
57     const HloInstruction* instruction) const {
58   return FindOrDie(domain_metadata_id_, instruction);
59 }
60 
TryProcessEmptyDomain(HloInstruction * instruction)61 Status HloDomainMap::TryProcessEmptyDomain(HloInstruction* instruction) {
62   TF_RET_CHECK(instruction->opcode() == HloOpcode::kDomain);
63   // We only check operands, so we are sure to not process the empty domain from
64   // both sides.
65   for (HloInstruction* operand : instruction->unique_operands()) {
66     if (IsDomainInstruction(operand)) {
67       auto domain = absl::make_unique<DomainMetadata::Domain>();
68       domain->enter_domains.insert(operand);
69       domain->exit_domains.insert(instruction);
70       TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
71     }
72   }
73   if (instruction == instruction->parent()->root_instruction()) {
74     auto domain = absl::make_unique<DomainMetadata::Domain>();
75     domain->enter_domains.insert(instruction);
76     TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
77   }
78   return Status::OK();
79 }
80 
Populate(HloComputation * computation)81 Status HloDomainMap::Populate(HloComputation* computation) {
82   InstructionOrderMap instructions_post_order;
83   int64 count = 0;
84   for (HloInstruction* instruction : computation->MakeInstructionPostOrder()) {
85     instructions_post_order.insert(std::make_pair(instruction, count++));
86   }
87   for (HloInstruction* instruction : computation->instructions()) {
88     if (IsDomainInstruction(instruction)) {
89       // If this is a kDomain of the kind we are currently processing, check
90       // whether this is an "empty domain".
91       TF_RETURN_IF_ERROR(TryProcessEmptyDomain(instruction));
92       continue;
93     }
94     int64 domain_id = FindOrDefault(instruction_to_domain_, instruction, -1);
95     if (domain_id >= 0) {
96       // We have already processed this instruction.
97       continue;
98     }
99     TF_ASSIGN_OR_RETURN(std::unique_ptr<DomainMetadata::Domain> domain,
100                         CreateDomain(instruction, instructions_post_order));
101     TF_RETURN_IF_ERROR(InsertDomain(std::move(domain)));
102   }
103   TF_RETURN_IF_ERROR(PopulateDomainMetadataMap());
104   return Status::OK();
105 }
106 
PopulateDomainMetadataMap()107 Status HloDomainMap::PopulateDomainMetadataMap() {
108   auto hash = [](const DomainMetadata* m) { return m->Hash(); };
109   auto equal = [](const DomainMetadata* a, const DomainMetadata* b) {
110     return a->Matches(*b);
111   };
112   absl::flat_hash_map<const DomainMetadata*, int64, decltype(hash),
113                       decltype(equal)>
114       domain_metadata(1024, hash, equal);
115 
116   for (auto& domain : instruction_domains_) {
117     int64 domain_metadata_id = -1;
118     if (!domain->enter_domains.empty()) {
119       const HloInstruction* domain_instruction = *domain->enter_domains.begin();
120       domain_metadata_id =
121           domain_metadata
122               .insert({&domain_instruction->user_side_metadata(),
123                        domain_metadata.size() + 1})
124               .first->second;
125     } else if (!domain->exit_domains.empty()) {
126       const HloInstruction* domain_instruction = *domain->exit_domains.begin();
127       domain_metadata_id =
128           domain_metadata
129               .insert({&domain_instruction->operand_side_metadata(),
130                        domain_metadata.size() + 1})
131               .first->second;
132     } else {
133       domain_metadata_id = 0;
134     }
135     TF_RET_CHECK(domain_metadata_id >= 0);
136     for (HloInstruction* instruction : domain->instructions) {
137       domain_metadata_id_[instruction] = domain_metadata_id;
138     }
139   }
140   return Status::OK();
141 }
142 
InsertDomain(std::unique_ptr<DomainMetadata::Domain> domain)143 Status HloDomainMap::InsertDomain(
144     std::unique_ptr<DomainMetadata::Domain> domain) {
145   int64 domain_id = instruction_domains_.size();
146   instruction_domains_.push_back(std::move(domain));
147   for (HloInstruction* instruction : instruction_domains_.back()->reach_set) {
148     instruction_to_domain_[instruction] = domain_id;
149   }
150   return Status::OK();
151 }
152 
ExpandDomain(HloInstruction * instruction,DomainMetadata::Domain * domain) const153 Status HloDomainMap::ExpandDomain(HloInstruction* instruction,
154                                   DomainMetadata::Domain* domain) const {
155   std::vector<HloInstruction*> in_queue;
156   in_queue.push_back(instruction);
157   while (!in_queue.empty()) {
158     HloInstruction* current_instruction = in_queue.back();
159     in_queue.pop_back();
160     if (domain->reach_set.insert(current_instruction).second) {
161       // We should not be finding instructions with assigned domain here.
162       // If we assigned a domain to the instruction, it means that all the
163       // instructions reached by it, should have a domain as well.
164       int64 domain_id =
165           FindOrDefault(instruction_to_domain_, current_instruction, -1);
166       TF_RET_CHECK(domain_id < 0)
167           << "Instruction " << current_instruction->ToString()
168           << " already has domain " << domain_id;
169       for (HloInstruction* operand : current_instruction->operands()) {
170         if (IsDomainInstruction(operand)) {
171           // The reach set instruction is a user of the domain instruction
172           // (the instruction sees the kDomain as operand).
173           // IOW the dataflow enters the domain through the kDomain instruction.
174           domain->enter_domains.insert(operand);
175         } else {
176           in_queue.push_back(operand);
177         }
178       }
179       for (HloInstruction* user : current_instruction->users()) {
180         if (IsDomainInstruction(user)) {
181           // The reach set instruction is an operand of the domain instruction
182           // (the instruction sees the kDomain as user).
183           // IOW the dataflow exits the domain through the kDomain instruction.
184           domain->exit_domains.insert(user);
185         } else {
186           in_queue.push_back(user);
187         }
188       }
189     }
190   }
191   return Status::OK();
192 }
193 
CreateDomain(HloInstruction * instruction,const InstructionOrderMap & instructions_order) const194 StatusOr<std::unique_ptr<DomainMetadata::Domain>> HloDomainMap::CreateDomain(
195     HloInstruction* instruction,
196     const InstructionOrderMap& instructions_order) const {
197   auto domain = absl::make_unique<DomainMetadata::Domain>();
198   TF_RETURN_IF_ERROR(ExpandDomain(instruction, domain.get()));
199   domain->instructions =
200       MakeNonDomainInstructions(domain->reach_set, instructions_order);
201   return std::move(domain);
202 }
203 
IsDomainInstruction(const HloInstruction * instruction) const204 bool HloDomainMap::IsDomainInstruction(
205     const HloInstruction* instruction) const {
206   if (instruction->opcode() != HloOpcode::kDomain) {
207     return false;
208   }
209   if (!domain_kind_.empty()) {
210     if (instruction->user_side_metadata().Kind() != domain_kind_) {
211       return false;
212     }
213     // Both user and operand side of the metadata must be of the same kind.
214     CHECK(instruction->operand_side_metadata().Kind() == domain_kind_)
215         << "Instruction " << instruction->ToString()
216         << " has mismatching metadata kinds";
217   }
218   return true;
219 }
220 
221 /* static */ std::vector<HloInstruction*>
MakeNonDomainInstructions(const absl::flat_hash_set<HloInstruction * > & instruction_set,const InstructionOrderMap & instructions_order)222 HloDomainMap::MakeNonDomainInstructions(
223     const absl::flat_hash_set<HloInstruction*>& instruction_set,
224     const InstructionOrderMap& instructions_order) {
225   std::vector<HloInstruction*> instructions;
226   instructions.reserve(instruction_set.size());
227   for (HloInstruction* instruction : instruction_set) {
228     if (instruction->opcode() != HloOpcode::kDomain) {
229       instructions.push_back(instruction);
230     }
231   }
232   // sort instructions according to instructions_order
233   absl::c_sort(instructions,
234                [&instructions_order](HloInstruction* a, HloInstruction* b) {
235                  return instructions_order.at(a) < instructions_order.at(b);
236                });
237   return instructions;
238 }
239 
240 }  // namespace xla
241