1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_cse.h"
17
18 #include <list>
19 #include <map>
20 #include <memory>
21 #include <set>
22 #include <string>
23 #include <utility>
24 #include <vector>
25
26 #include "absl/container/flat_hash_set.h"
27 #include "absl/container/inlined_vector.h"
28 #include "tensorflow/compiler/xla/layout_util.h"
29 #include "tensorflow/compiler/xla/literal.h"
30 #include "tensorflow/compiler/xla/service/hlo_computation.h"
31 #include "tensorflow/compiler/xla/service/hlo_domain_map.h"
32 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
33 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
34 #include "tensorflow/compiler/xla/shape_util.h"
35 #include "tensorflow/compiler/xla/types.h"
36 #include "tensorflow/core/lib/core/errors.h"
37 #include "tensorflow/core/lib/hash/hash.h"
38
39 namespace xla {
40
41 namespace {
42
43 // Find and combine identical constants. Constants are identical if they have
44 // the same type and value.
CombineConstants(HloComputation * computation,bool is_layout_sensitive)45 StatusOr<bool> CombineConstants(HloComputation* computation,
46 bool is_layout_sensitive) {
47 TF_ASSIGN_OR_RETURN(auto domain_map, HloDomainMap::Create(computation, ""));
48 // Map from ShortDebugString of the layoutless shape of the constant to the
49 // set of constant instructions with that shape. Layoutless shape is used to
50 // bin possible common constants together to reduce number of constant
51 // comparisons. If we end up having too many constant comparisons, a more
52 // precise binning might have to be used.
53 std::multimap<string, HloInstruction*> constants;
54 int64 combined = 0;
55 auto inst_it = computation->instructions().begin();
56 while (inst_it != computation->instructions().end()) {
57 HloInstruction* instruction = *inst_it;
58
59 // Advance list iterator before loop body because iterator may be
60 // invalidated due to deletion.
61 ++inst_it;
62
63 if (instruction->opcode() == HloOpcode::kConstant) {
64 Shape shape = instruction->shape();
65 if (!is_layout_sensitive) {
66 LayoutUtil::ClearLayout(&shape);
67 }
68 string shape_string = shape.ShortDebugString();
69
70 // Compare against all constants with the same shape
71 auto range = constants.equal_range(shape_string);
72 HloInstruction* match = nullptr;
73 for (auto it = range.first; it != range.second; ++it) {
74 if (instruction->literal() == it->second->literal() &&
75 domain_map->InSameDomain(it->second, instruction)) {
76 match = it->second;
77 break;
78 }
79 }
80 if (match == nullptr) {
81 constants.emplace(shape_string, instruction);
82 } else {
83 // Match found, replace this instruction with the one in the multimap.
84 TF_CHECK_OK(instruction->ReplaceAllUsesWith(match));
85 TF_CHECK_OK(computation->RemoveInstruction(instruction));
86 ++combined;
87 }
88 }
89 }
90 VLOG(4) << "Combined " << combined << " constants in " << computation->name()
91 << " computation";
92 return combined > 0;
93 }
94
95 // An instruction is considered to be equivalent to another only if they
96 // share the exact same set of operands.
CseHash(const HloInstruction * instruction)97 int64 CseHash(const HloInstruction* instruction) {
98 int64 hash = std::hash<int64>()(static_cast<int64>(instruction->opcode()));
99 hash = tensorflow::Hash64Combine(
100 hash, instruction->opcode() == HloOpcode::kGetTupleElement
101 ? instruction->tuple_index()
102 : -1);
103 for (auto operand : instruction->operands()) {
104 hash = tensorflow::Hash64Combine(hash, operand->unique_id());
105 }
106 if (instruction->opcode() == HloOpcode::kConstant) {
107 hash = tensorflow::Hash64Combine(hash, instruction->literal().Hash());
108 }
109 return hash;
110 }
111
112 } // namespace
113
Run(HloModule * module)114 StatusOr<bool> HloCSE::Run(HloModule* module) {
115 bool changed = false;
116 const std::function<bool(const HloInstruction*, const HloInstruction*)>
117 eq_instructions = std::equal_to<const HloInstruction*>();
118 const std::function<bool(const HloComputation*, const HloComputation*)>
119 eq_computations = [](const HloComputation* lhs,
120 const HloComputation* rhs) { return *lhs == *rhs; };
121
122 auto cse_equal = [&](const HloInstruction* lhs, const HloInstruction* rhs) {
123 return lhs->Identical(*rhs, eq_instructions, eq_computations,
124 is_layout_sensitive_);
125 };
126
127 for (auto* computation : module->computations()) {
128 if (only_fusion_computations_ && !computation->IsFusionComputation()) {
129 continue;
130 }
131
132 TF_ASSIGN_OR_RETURN(bool combined,
133 CombineConstants(computation, is_layout_sensitive_));
134 changed |= combined;
135
136 // HLO instructions are grouped into equivalency classes by using the
137 // cse_equal predicate defined above. This set holds a representative
138 // instruction for each class.
139 absl::flat_hash_set<HloInstruction*, decltype(&CseHash),
140 decltype(cse_equal)>
141 representatives(/*N=*/computation->instruction_count() + 1, &CseHash,
142 cse_equal);
143 for (auto instruction : computation->MakeInstructionPostOrder()) {
144 // If the instruction has zero operands (constants, parameters, etc.) skip
145 // over it.
146 if (instruction->operand_count() == 0) {
147 continue;
148 }
149 // Skip instructions which have side effects.
150 if (instruction->HasSideEffect()) {
151 continue;
152 }
153
154 auto it = representatives.find(instruction);
155 if (it != representatives.end()) {
156 HloInstruction* equivalent_instruction = *it;
157 TF_RETURN_IF_ERROR(
158 instruction->ReplaceAllUsesWith(equivalent_instruction));
159 TF_RETURN_IF_ERROR(computation->RemoveInstruction(instruction));
160 changed = true;
161 continue;
162 }
163 representatives.insert(instruction);
164 }
165 }
166 return changed;
167 }
168
169 } // namespace xla
170