1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
17
18 #include <memory>
19 #include <string>
20 #include <utility>
21 #include <vector>
22
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/service/hlo_query.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/core/lib/core/errors.h"
35
36 namespace xla {
37
38 // Checks whether instr is or transitively contains an instruction that we
39 // shouldn't fold.
40 //
41 // Specifically, we don't fold kRng or kAfterAll instructions:
42 //
43 // - kRng is already marked as side-effecting and so is skipped elsewhere, but
44 // we check for it here. Even kRng weren't side-effecting and took an
45 // explicit seed, we *still* wouldn't want to constant-fold it, because the
46 // evaluator's handling of rng is not guaranteed to be identical to any
47 // particular backend's rng.
48 //
49 // - kAfterAll needs to be skipped because a kAfterAll op with no args can
50 // currently materialize a token "out of thin air". TODO(b/110532604):
51 // Remove this check once AfterAll requires at least one operand, in which
52 // case constant folding will be impossible.
IsOrContainsIllegalInstr(const HloInstruction * instr)53 static bool IsOrContainsIllegalInstr(const HloInstruction* instr) {
54 if (instr->opcode() == HloOpcode::kAfterAll ||
55 instr->opcode() == HloOpcode::kRng) {
56 return true;
57 }
58 for (const HloComputation* c : instr->called_computations()) {
59 if (absl::c_any_of(c->instructions(), IsOrContainsIllegalInstr)) {
60 return true;
61 }
62 }
63 return false;
64 }
65
Run(HloModule * module)66 StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
67 // Limit the constant folding to 0 iterations to skip folding loops. This
68 // retains the behavior from before while loop support in HloEvaluator and may
69 // be revised.
70 auto evaluator = absl::make_unique<HloEvaluator>(/*max_loop_iterations=*/0);
71
72 XLA_VLOG_LINES(2,
73 "HloConstantFolding::Run(), before:\n" + module->ToString());
74 bool changed = false;
75
76 for (auto* computation : module->MakeNonfusionComputations()) {
77 for (auto instruction : computation->MakeInstructionPostOrder()) {
78 // Skip dead code.
79 if (instruction->user_count() == 0 &&
80 computation->root_instruction() != instruction) {
81 continue;
82 }
83
84 // Skip instructions with non-constant operands.
85 if (!hlo_query::AllOperandsAreConstants(*instruction)) {
86 continue;
87 }
88
89 // Don't fold Constant, Parameter, and Tuple instructions. Tuple
90 // constants are not directly supported by any backends, hence folding
91 // Tuple is not useful and would in fact be expanded back into kTuple by
92 // Algebraic Simplifier.
93 //
94 // (We do allow folding subcomputations that contain these instructions.)
95 if (instruction->opcode() == HloOpcode::kParameter ||
96 instruction->opcode() == HloOpcode::kConstant ||
97 instruction->opcode() == HloOpcode::kTuple) {
98 continue;
99 }
100
101 // Broadcasts dramatically increase the size of constants, which is often
102 // detrimental to performance and memory capacity, so do not fold
103 // broadcasts.
104 if (instruction->opcode() == HloOpcode::kBroadcast ||
105 instruction->opcode() == HloOpcode::kIota) {
106 continue;
107 }
108
109 // Check for instructions that we can't fold even if they appear inside of
110 // a subcomputation (e.g. a kCall).
111 if (IsOrContainsIllegalInstr(instruction)) {
112 continue;
113 }
114
115 // Don't constant-fold side-effecting instructions or instructions which
116 // contain side-effecting instructions.
117 if (instruction->HasSideEffect()) {
118 continue;
119 }
120
121 // Don't constant fold unless it's a net positive or the output is small.
122 if (instruction->shape().IsArray()) {
123 int64 elements_in_removed_operands = 0;
124 for (HloInstruction* operand : instruction->operands()) {
125 if (operand->user_count() == 1 && operand->shape().IsArray()) {
126 elements_in_removed_operands +=
127 ShapeUtil::ElementsIn(operand->shape());
128 }
129 }
130 int64 elements_in_constant =
131 ShapeUtil::ElementsIn(instruction->shape());
132
133 static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000;
134 if (elements_in_constant > elements_in_removed_operands &&
135 elements_in_constant > kMaximumConstantSizeElements) {
136 continue;
137 }
138 }
139
140 Literal result;
141 // Currently we skip unimplemented operations.
142 // TODO(b/35975797): Fold constant computations for more operations.
143 if (!evaluator->TryEvaluate(instruction, &result)) {
144 VLOG(2) << "Constant folding failed for instruction: "
145 << instruction->ToString();
146 continue;
147 }
148 VLOG(4) << "Constant folded: " << instruction->ToString();
149
150 TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
151 instruction, HloInstruction::CreateConstant(std::move(result))));
152 changed = true;
153 }
154 }
155 XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
156 return changed;
157 }
158
159 } // namespace xla
160