1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/xla/service/hlo_constant_folding.h"
17 
#include <memory>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
22 
23 #include "absl/memory/memory.h"
24 #include "tensorflow/compiler/xla/layout_util.h"
25 #include "tensorflow/compiler/xla/literal.h"
26 #include "tensorflow/compiler/xla/service/dfs_hlo_visitor_with_default.h"
27 #include "tensorflow/compiler/xla/service/hlo_computation.h"
28 #include "tensorflow/compiler/xla/service/hlo_evaluator.h"
29 #include "tensorflow/compiler/xla/service/hlo_instruction.h"
30 #include "tensorflow/compiler/xla/service/hlo_opcode.h"
31 #include "tensorflow/compiler/xla/service/hlo_query.h"
32 #include "tensorflow/compiler/xla/shape_util.h"
33 #include "tensorflow/compiler/xla/types.h"
34 #include "tensorflow/core/lib/core/errors.h"
35 
36 namespace xla {
37 
38 // Checks whether instr is or transitively contains an instruction that we
39 // shouldn't fold.
40 //
41 // Specifically, we don't fold kRng or kAfterAll instructions:
42 //
43 //  - kRng is already marked as side-effecting and so is skipped elsewhere, but
//    we check for it here.  Even if kRng weren't side-effecting and took an
45 //    explicit seed, we *still* wouldn't want to constant-fold it, because the
46 //    evaluator's handling of rng is not guaranteed to be identical to any
47 //    particular backend's rng.
48 //
49 //  - kAfterAll needs to be skipped because a kAfterAll op with no args can
50 //    currently materialize a token "out of thin air".  TODO(b/110532604):
51 //    Remove this check once AfterAll requires at least one operand, in which
52 //    case constant folding will be impossible.
IsOrContainsIllegalInstr(const HloInstruction * instr)53 static bool IsOrContainsIllegalInstr(const HloInstruction* instr) {
54   if (instr->opcode() == HloOpcode::kAfterAll ||
55       instr->opcode() == HloOpcode::kRng) {
56     return true;
57   }
58   for (const HloComputation* c : instr->called_computations()) {
59     if (absl::c_any_of(c->instructions(), IsOrContainsIllegalInstr)) {
60       return true;
61     }
62   }
63   return false;
64 }
65 
Run(HloModule * module)66 StatusOr<bool> HloConstantFolding::Run(HloModule* module) {
67   // Limit the constant folding to 0 iterations to skip folding loops. This
68   // retains the behavior from before while loop support in HloEvaluator and may
69   // be revised.
70   auto evaluator = absl::make_unique<HloEvaluator>(/*max_loop_iterations=*/0);
71 
72   XLA_VLOG_LINES(2,
73                  "HloConstantFolding::Run(), before:\n" + module->ToString());
74   bool changed = false;
75 
76   for (auto* computation : module->MakeNonfusionComputations()) {
77     for (auto instruction : computation->MakeInstructionPostOrder()) {
78       // Skip dead code.
79       if (instruction->user_count() == 0 &&
80           computation->root_instruction() != instruction) {
81         continue;
82       }
83 
84       // Skip instructions with non-constant operands.
85       if (!hlo_query::AllOperandsAreConstants(*instruction)) {
86         continue;
87       }
88 
89       // Don't fold Constant, Parameter, and Tuple instructions.  Tuple
90       // constants are not directly supported by any backends, hence folding
91       // Tuple is not useful and would in fact be expanded back into kTuple by
92       // Algebraic Simplifier.
93       //
94       // (We do allow folding subcomputations that contain these instructions.)
95       if (instruction->opcode() == HloOpcode::kParameter ||
96           instruction->opcode() == HloOpcode::kConstant ||
97           instruction->opcode() == HloOpcode::kTuple) {
98         continue;
99       }
100 
101       // Broadcasts dramatically increase the size of constants, which is often
102       // detrimental to performance and memory capacity, so do not fold
103       // broadcasts.
104       if (instruction->opcode() == HloOpcode::kBroadcast ||
105           instruction->opcode() == HloOpcode::kIota) {
106         continue;
107       }
108 
109       // Check for instructions that we can't fold even if they appear inside of
110       // a subcomputation (e.g. a kCall).
111       if (IsOrContainsIllegalInstr(instruction)) {
112         continue;
113       }
114 
115       // Don't constant-fold side-effecting instructions or instructions which
116       // contain side-effecting instructions.
117       if (instruction->HasSideEffect()) {
118         continue;
119       }
120 
121       // Don't constant fold unless it's a net positive or the output is small.
122       if (instruction->shape().IsArray()) {
123         int64 elements_in_removed_operands = 0;
124         for (HloInstruction* operand : instruction->operands()) {
125           if (operand->user_count() == 1 && operand->shape().IsArray()) {
126             elements_in_removed_operands +=
127                 ShapeUtil::ElementsIn(operand->shape());
128           }
129         }
130         int64 elements_in_constant =
131             ShapeUtil::ElementsIn(instruction->shape());
132 
133         static const int64 kMaximumConstantSizeElements = 2 * 1000 * 1000;
134         if (elements_in_constant > elements_in_removed_operands &&
135             elements_in_constant > kMaximumConstantSizeElements) {
136           continue;
137         }
138       }
139 
140       Literal result;
141       // Currently we skip unimplemented operations.
142       // TODO(b/35975797): Fold constant computations for more operations.
143       if (!evaluator->TryEvaluate(instruction, &result)) {
144         VLOG(2) << "Constant folding failed for instruction: "
145                 << instruction->ToString();
146         continue;
147       }
148       VLOG(4) << "Constant folded: " << instruction->ToString();
149 
150       TF_RETURN_IF_ERROR(computation->ReplaceWithNewInstruction(
151           instruction, HloInstruction::CreateConstant(std::move(result))));
152       changed = true;
153     }
154   }
155   XLA_VLOG_LINES(2, "HloConstantFolding::Run(), after:\n" + module->ToString());
156   return changed;
157 }
158 
159 }  // namespace xla
160