1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/xla/service/call_inliner.h"
17
18 #include <deque>
19
20 #include "tensorflow/compiler/xla/service/call_graph.h"
21 #include "tensorflow/compiler/xla/service/hlo_dce.h"
22 #include "tensorflow/core/lib/core/errors.h"
23
24 namespace xla {
25 namespace {
26
27 // Traverses the callee computation, inlining cloned nodes into the caller
28 // computation and connecting them to producers/consumers appropriately.
29 // When the traversal has completed, the provided call instruction is entriely
30 // replaced in the caller's graph.
31 class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
32 public:
33 // call is the call operation -- it will be replaced with the body of the
34 // called computation.
SubcomputationInsertionVisitor(HloInstruction * call)35 explicit SubcomputationInsertionVisitor(HloInstruction* call)
36 : call_(call), outer_(call->parent()) {
37 CHECK_EQ(HloOpcode::kCall, call_->opcode());
38 }
39
40 // Resolves the operands to the HLO instruction in the inlined (caller) graph,
41 // and clones the HLO instruction into that graph with the new operands.
42 // If the instruction is a call, it is added to the work queue.
DefaultAction(HloInstruction * hlo)43 Status DefaultAction(HloInstruction* hlo) override {
44 TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
45 std::vector<HloInstruction*> new_operands;
46 for (HloInstruction* operand : hlo->operands()) {
47 TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
48 new_operands.push_back(new_operand);
49 }
50 VLOG(1) << "Cloning HLO and adding to caller: " << hlo->ToString();
51 auto new_hlo = hlo->CloneWithNewOperands(hlo->shape(), new_operands);
52 HloInstruction* new_hlo_pointer =
53 outer_->AddInstruction(std::move(new_hlo));
54 TF_RETURN_IF_ERROR(NoteMapping(hlo, new_hlo_pointer));
55
56 // Account for control edges.
57 for (HloInstruction* control_predecessor : hlo->control_predecessors()) {
58 TF_ASSIGN_OR_RETURN(HloInstruction * new_control_predecessor,
59 Resolve(control_predecessor));
60 TF_RETURN_IF_ERROR(
61 new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
62 }
63
64 return Status::OK();
65 }
66
67 // Does not create new nodes for the parameter; rather, notes the mapping from
68 // the subcomputation parameter node to the call operands in the caller
69 // computation.
HandleParameter(HloInstruction * parameter)70 Status HandleParameter(HloInstruction* parameter) override {
71 TF_RETURN_IF_ERROR(NoteMapping(
72 parameter, call_->mutable_operand(parameter->parameter_number())));
73 return Status::OK();
74 }
75
76 // Wires the consumers of the call to instead point at the newly created root,
77 // replacing the call operation in the caller computation.
FinishVisit(HloInstruction * root)78 Status FinishVisit(HloInstruction* root) override {
79 TF_ASSIGN_OR_RETURN(HloInstruction * new_root, Resolve(root));
80 VLOG(1) << "Replacing all uses of " << call_->ToString()
81 << " with new root " << new_root->ToString();
82 call_->ClearCalledComputations();
83 return outer_->ReplaceInstruction(call_, new_root);
84 }
85
ConsumeInstructionMap()86 CallInliner::InlinedInstructionMap ConsumeInstructionMap() {
87 return std::move(subcomputation_hlo_to_new_hlo_);
88 }
89
90 private:
91 // Resolves the callee subcomputation_hlo to the new (inline) HLO in the
92 // caller computation, or returns a NotFound error if that subcomputation HLO
93 // has not been mapped.
Resolve(HloInstruction * subcomputation_hlo)94 StatusOr<HloInstruction*> Resolve(HloInstruction* subcomputation_hlo) {
95 auto it = subcomputation_hlo_to_new_hlo_.find(subcomputation_hlo);
96 if (it == subcomputation_hlo_to_new_hlo_.end()) {
97 return NotFound(
98 "Could not find mapping from subcomputation HLO %s to a cloned HLO.",
99 subcomputation_hlo->ToString());
100 }
101 return it->second;
102 }
103
104 // Notes that the given subcomputation_hlo in the callee has been mapped to
105 // the (inline) new_hlo in the caller computation.
106 //
107 // Returns an error status if the subcomputation_hlo is mapped more than
108 // once.
NoteMapping(HloInstruction * subcomputation_hlo,HloInstruction * new_hlo)109 Status NoteMapping(HloInstruction* subcomputation_hlo,
110 HloInstruction* new_hlo) {
111 auto result = subcomputation_hlo_to_new_hlo_.insert(
112 std::make_pair(subcomputation_hlo, new_hlo));
113 TF_RET_CHECK(result.second)
114 << "A mapping for the subcomputation HLO is already present.";
115 return Status::OK();
116 }
117
118 HloInstruction* call_;
119 HloComputation* outer_;
120 CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_;
121 };
122
123 } // namespace
124
Inline(HloInstruction * call)125 /* static */ StatusOr<CallInliner::InlinedInstructionMap> CallInliner::Inline(
126 HloInstruction* call) {
127 TF_RET_CHECK(call->opcode() == HloOpcode::kCall)
128 << "Instruction was not a call op: " << call->opcode();
129 const auto& callees = call->called_computations();
130 TF_RET_CHECK(callees.size() == 1);
131 HloComputation* callee = callees[0];
132 // We visit the callee, cloning its body into its caller.
133 SubcomputationInsertionVisitor visitor(call);
134 TF_RETURN_IF_ERROR(callee->Accept(&visitor));
135 return visitor.ConsumeInstructionMap();
136 }
137
Run(HloModule * module)138 StatusOr<bool> CallInliner::Run(HloModule* module) {
139 std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
140 // Because call graph nodes are visited in post-order (callees before callers)
141 // we'll always inline kCalls into their callers in the appropriate order.
142 bool did_mutate = false;
143 TF_RETURN_IF_ERROR(
144 call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
145 for (const CallSite& callsite : node.caller_callsites()) {
146 VLOG(1) << "Visiting callsite: " << callsite.ToString();
147 if (callsite.instruction()->opcode() == HloOpcode::kCall) {
148 HloInstruction* call = callsite.instruction();
149 TF_RETURN_IF_ERROR(Inline(call).status());
150 did_mutate = true;
151 }
152 }
153 return Status::OK();
154 }));
155 if (did_mutate) {
156 // Run DCE to remove called computations which are now becoming unused.
157 // This can result then in problems if within the called computation, there
158 // were send/recv instructions, which the module group verifier will flag as
159 // error findingthe same channel ID used for multiple send/recv
160 // instructions.
161 TF_RETURN_IF_ERROR(HloDCE().Run(module).status());
162 }
163 return did_mutate;
164 }
165
166 } // namespace xla
167