/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/xla/service/call_inliner.h"

#include "tensorflow/compiler/xla/service/call_graph.h"
#include "tensorflow/compiler/xla/service/hlo_dce.h"
#include "tensorflow/core/lib/core/errors.h"

namespace xla {
namespace {

// Traverses the callee computation, inlining cloned nodes into the caller
// computation and connecting them to producers/consumers appropriately.
// When the traversal has completed, the provided call instruction is entirely
// replaced in the caller's graph.
class SubcomputationInsertionVisitor : public DfsHloVisitorWithDefault {
 public:
  // call is the call operation -- it will be replaced with the body of the
  // called computation.
  explicit SubcomputationInsertionVisitor(HloInstruction* call)
      : call_(call), outer_(call->parent()) {
    CHECK_EQ(HloOpcode::kCall, call_->opcode());
  }

  // Resolves the operands to the HLO instruction in the inlined (caller)
  // graph, and clones the HLO instruction into that graph with the new
  // operands. Nested kCall instructions are not handled here and are rejected
  // with an error; Run() inlines calls in post-order (callees before
  // callers), so no nested calls remain by the time a caller is visited.
  Status DefaultAction(HloInstruction* hlo) override {
    TF_RET_CHECK(hlo->opcode() != HloOpcode::kCall);
    std::vector<HloInstruction*> new_operands;
    for (HloInstruction* operand : hlo->operands()) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_operand, Resolve(operand));
      new_operands.push_back(new_operand);
    }
    VLOG(1) << "Cloning HLO and adding to caller: " << hlo->ToString();
    auto new_hlo = hlo->CloneWithNewOperands(hlo->shape(), new_operands);
    HloInstruction* new_hlo_pointer =
        outer_->AddInstruction(std::move(new_hlo));
    TF_RETURN_IF_ERROR(NoteMapping(hlo, new_hlo_pointer));

    // Account for control edges.
    for (HloInstruction* control_predecessor : hlo->control_predecessors()) {
      TF_ASSIGN_OR_RETURN(HloInstruction * new_control_predecessor,
                          Resolve(control_predecessor));
      TF_RETURN_IF_ERROR(
          new_control_predecessor->AddControlDependencyTo(new_hlo_pointer));
    }

    return Status::OK();
  }

  // Does not create new nodes for the parameter; rather, notes the mapping
  // from the subcomputation parameter node to the corresponding call operand
  // in the caller computation.
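  //
  // For example (a hypothetical HLO sketch; the names and shapes are
  // illustrative, not taken from any real module):
  //
  //   callee {
  //     p0 = f32[] parameter(0)
  //     ROOT neg = f32[] negate(p0)
  //   }
  //   ENTRY caller {
  //     x = f32[] constant(1)
  //     ROOT c = f32[] call(x), to_apply=callee
  //   }
  //
  // Visiting p0 records the mapping p0 -> x, so the clone of `neg` produced
  // by DefaultAction takes x as its operand directly; no parameter node is
  // added to the caller.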
  Status HandleParameter(HloInstruction* parameter) override {
    TF_RETURN_IF_ERROR(NoteMapping(
        parameter, call_->mutable_operand(parameter->parameter_number())));
    return Status::OK();
  }

  // Wires the consumers of the call to instead point at the newly created
  // root, replacing the call operation in the caller computation.
  Status FinishVisit(HloInstruction* root) override {
    TF_ASSIGN_OR_RETURN(HloInstruction * new_root, Resolve(root));
    VLOG(1) << "Replacing all uses of " << call_->ToString()
            << " with new root " << new_root->ToString();
    call_->ClearCalledComputations();
    return outer_->ReplaceInstruction(call_, new_root);
  }

  CallInliner::InlinedInstructionMap ConsumeInstructionMap() {
    return std::move(subcomputation_hlo_to_new_hlo_);
  }

 private:
  // Resolves the callee subcomputation_hlo to the new (inline) HLO in the
  // caller computation, or returns a NotFound error if that subcomputation
  // HLO has not been mapped.
  StatusOr<HloInstruction*> Resolve(HloInstruction* subcomputation_hlo) {
    auto it = subcomputation_hlo_to_new_hlo_.find(subcomputation_hlo);
    if (it == subcomputation_hlo_to_new_hlo_.end()) {
      return NotFound(
          "Could not find mapping from subcomputation HLO %s to a cloned HLO.",
          subcomputation_hlo->ToString());
    }
    return it->second;
  }

  // Notes that the given subcomputation_hlo in the callee has been mapped to
  // the (inline) new_hlo in the caller computation.
  //
  // Returns an error status if the subcomputation_hlo is mapped more than
  // once.
  Status NoteMapping(HloInstruction* subcomputation_hlo,
                     HloInstruction* new_hlo) {
    auto result = subcomputation_hlo_to_new_hlo_.insert(
        std::make_pair(subcomputation_hlo, new_hlo));
    TF_RET_CHECK(result.second)
        << "A mapping for the subcomputation HLO is already present.";
    return Status::OK();
  }

  HloInstruction* call_;
  HloComputation* outer_;
  CallInliner::InlinedInstructionMap subcomputation_hlo_to_new_hlo_;
};

}  // namespace

/* static */ StatusOr<CallInliner::InlinedInstructionMap> CallInliner::Inline(
    HloInstruction* call) {
  TF_RET_CHECK(call->opcode() == HloOpcode::kCall)
      << "Instruction was not a call op: " << call->opcode();
  const auto& callees = call->called_computations();
  TF_RET_CHECK(callees.size() == 1);
  HloComputation* callee = callees[0];
  // We visit the callee, cloning its body into its caller.
  SubcomputationInsertionVisitor visitor(call);
  TF_RETURN_IF_ERROR(callee->Accept(&visitor));
  return visitor.ConsumeInstructionMap();
}
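
// Illustrative use of Inline() (a sketch, not code from this file; assumes
// `call` is a kCall instruction inside some computation, in a context where
// TF_ASSIGN_OR_RETURN may be used):
//
//   TF_ASSIGN_OR_RETURN(CallInliner::InlinedInstructionMap inlined,
//                       CallInliner::Inline(call));
//   // `inlined` maps each instruction of the former callee body to its
//   // clone, which now lives in the caller computation.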

StatusOr<bool> CallInliner::Run(HloModule* module) {
  std::unique_ptr<CallGraph> call_graph = CallGraph::Build(module);
  // Because call graph nodes are visited in post-order (callees before
  // callers), we'll always inline kCalls into their callers in the
  // appropriate order.
  bool did_mutate = false;
  TF_RETURN_IF_ERROR(
      call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
        for (const CallSite& callsite : node.caller_callsites()) {
          VLOG(1) << "Visiting callsite: " << callsite.ToString();
          if (callsite.instruction()->opcode() == HloOpcode::kCall) {
            HloInstruction* call = callsite.instruction();
            TF_RETURN_IF_ERROR(Inline(call).status());
            did_mutate = true;
          }
        }
        return Status::OK();
      }));
  if (did_mutate) {
    // Run DCE to remove the called computations, which are now unused. If
    // they were left in place and contained send/recv instructions, the
    // module group verifier would flag an error for the same channel ID being
    // used by multiple send/recv instructions.
    TF_RETURN_IF_ERROR(HloDCE().Run(module).status());
  }
  return did_mutate;
}
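
// Illustrative use as a module pass (a sketch; assumes the HloPassPipeline
// interface from hlo_pass_pipeline.h, with `module` an HloModule*):
//
//   HloPassPipeline pipeline("call-inliner");
//   pipeline.AddPass<CallInliner>();
//   TF_ASSIGN_OR_RETURN(bool changed, pipeline.Run(module));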

}  // namespace xla