1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_cse.h"
17 
18 #include <list>
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "tensorflow/compiler/xla/layout_util.h"
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_domain_map.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/hash/hash.h"
38 
39 namespace xla {
40 
41 namespace {
42 
43 // Find and combine identical constants. Constants are identical if they have
44 // the same type and value.
CombineConstants(HloComputation * computation,bool is_layout_sensitive)45 StatusOr<bool> CombineConstants(HloComputation* computation,
46                                 bool is_layout_sensitive) {
47   TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));
48   // Map from ShortDebugString of the layoutless shape of the constant to the
49   // set of constant instructions with that shape. Layoutless shape is used to
50   // bin possible common constants together to reduce number of constant
51   // comparisons. If we end up having too many constant comparisons, a more
52   // precise binning might have to be used.
53   std::multimap<string, HloInstruction*> constants;
54   int64 combined = 0;
55   auto inst_it = computation->instructions().begin();
56   while (inst_it != computation->instructions().end()) {
57     HloInstruction* instruction = *inst_it;
58 
59     // Advance list iterator before loop body because iterator may be
60     // invalidated due to deletion.
61     ++inst_it;
62 
63     if (instruction->opcode() == HloOpcode::kConstant) {
64       Shape shape = instruction->shape();
65       if (!is_layout_sensitive) {
66         LayoutUtil::ClearLayout(&shape);
67       }
68       string shape_string = shape.ShortDebugString();
69 
70       // Compare against all constants with the same shape
71       auto range = constants.equal_range(shape_string);
72       HloInstruction* match = nullptr;
73       for (auto it = range.first; it != range.second; ++it) {
74         if (instruction->literal() == it->second->literal() &&
75             domain_map->InSameDomain(it->second, instruction)) {
76           match = it->second;
77           break;
78         }
79       }
80       if (match == nullptr) {
81         constants.emplace(shape_string, instruction);
82       } else {
83         // Match found, replace this instruction with the one in the multimap.
84         TF_CHECK_OK(instruction->ReplaceAllUsesWith(match));
85         TF_CHECK_OK(computation->RemoveInstruction(instruction));
86         ++combined;
87       }
88     }
89   }
90   VLOG(4) << "Combined " << combined << " constants in " << computation->name()
91           << " computation";
92   return combined > 0;
93 }
94 
95 // An instruction is considered to be equivalent to another only if they
96 // share the exact same set of operands.
CseHash(const HloInstruction * instruction)97 int64 CseHash(const HloInstruction* instruction) {
98   int64 hash = std::hash<int64>()(static_cast<int64>(instruction->opcode()));
99   hash = tensorflow::Hash64Combine(
100       hash, instruction->opcode() == HloOpcode::kGetTupleElement
101                 ? instruction->tuple_index()
102                 : -1);
103   for (auto operand : instruction->operands()) {
104     hash = tensorflow::Hash64Combine(hash, operand->unique_id());
105   }
106   if (instruction->opcode() == HloOpcode::kConstant) {
107     hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash());
108   }
109   return hash;
110 }
111 
112 }  // namespace
113 
Run(HloModule * module)114 StatusOr<bool> HloCSE::Run(HloModule* module) {
115   bool changed = false;
116   const std::function<bool(const HloInstruction*, const HloInstruction*)>
117       eq_instructions = std::equal_to<const HloInstruction*>();
118   const std::function<bool(const HloComputation*, const HloComputation*)>
119       eq_computations = [](const HloComputation* lhs,
120                            const HloComputation* rhs) { return *lhs == *rhs; };
121 
122   auto cse_equal = [&](const HloInstruction* lhs, const HloInstruction* rhs) {
123     return lhs->Identical(*rhs, eq_instructions, eq_computations,
124                           is_layout_sensitive_);
125   };
126 
127   for (auto* computation : module->computations()) {
128     if (only_fusion_computations_ && !computation->IsFusionComputation()) {
129       continue;
130     }
131 
132     TF_ASSIGN_OR_RETURN(bool combined,
133                         CombineConstants(computation, is_layout_sensitive_));
134     changed |= combined;
135 
136     // HLO instructions are grouped into equivalency classes by using the
137     // cse_equal predicate defined above. This set holds a representative
138     // instruction for each class.
139     absl::flat_hash_set<HloInstruction*, decltype(&CseHash),
140                         decltype(cse_equal)>
141         representatives(/*N=*/computation->instruction_count() + 1, &CseHash,
142                         cse_equal);
143     for (auto instruction : computation->MakeInstructionPostOrder()) {
144       // If the instruction has zero operands (constants, parameters, etc.) skip
145       // over it.
146       if (instruction->operand_count() == 0) {
147         continue;
148       }
149       // Skip instructions which have side effects.
150       if (instruction->HasSideEffect()) {
151         continue;
152       }
153 
154       auto it = representatives.find(instruction);
155       if (it != representatives.end()) {
156         HloInstruction* equivalent_instruction = *it;
157         TF_RETURN_IF_ERROR(
158             instruction->ReplaceAllUsesWith(equivalent_instruction));
159         TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
160         changed = true;
161         continue;
162       }
163       representatives.insert(instruction);
164     }
165   }
166   return changed;
167 }
168 
169 }  // namespace xla
170