1 /* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
#include "tensorflow/core/common_runtime/inspecting_placer.h"

#include <map>
#include <memory>
#include <unordered_map>
#include <vector>

#include "absl/strings/str_join.h"
#include "tensorflow/core/common_runtime/colocation_graph.h"
#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/common_runtime/function_body.h"
#include "tensorflow/core/common_runtime/function_def_utils.h"
#include "tensorflow/core/common_runtime/placer_inspection_required_ops_utils.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
32 
33 namespace tensorflow {
34 
DebugString() const35 string IOColocationGroups::DebugString() const {
36   std::unordered_map<int, std::vector<string>> group_members;
37   for (int arg_index = 0; arg_index < input_groups.size(); ++arg_index) {
38     int group_id = input_groups[arg_index];
39     group_members[group_id].push_back(strings::StrCat("i:", arg_index));
40   }
41   for (int ret_index = 0; ret_index < output_groups.size(); ++ret_index) {
42     int group_id = output_groups[ret_index];
43     group_members[group_id].push_back(strings::StrCat("o:", ret_index));
44   }
45 
46   std::vector<string> group_strings;
47   for (const auto& it : group_members) {
48     int group_id = it.first;
49     const std::vector<string>& members = it.second;
50     const PossibleDevices& devices = group_devices[group_id];
51     group_strings.push_back(strings::StrCat(
52         "Group(", group_id, " members = [", absl::StrJoin(members, ", "),
53         "] requested_device_name = \"",
54         DeviceNameUtils::ParsedNameToString(devices.requested_device_name),
55         "\" resource_device_name = \"",
56         DeviceNameUtils::ParsedNameToString(devices.resource_device_name),
57         "\" device_types = [",
58         absl::StrJoin(
59             devices.device_types, ", ",
60             [](string* out, const std::pair<DeviceType, int32>& type_and_pref) {
61               out->append(DeviceTypeString(type_and_pref.first));
62             }),
63         "])"));
64   }
65 
66   return absl::StrJoin(group_strings, "\n\t");
67 }
68 
69 // Utility class for constructing IOColocationGroups from a ColocationGraph.
70 class ColocationGraphToIOColocationGroups {
71  public:
72   // colocation_graph is mutable because finding root nodes can update
73   // parent pointers. It is not modified otherwise.
ColocationGraphToIOColocationGroups(ColocationGraph * colocation_graph)74   explicit ColocationGraphToIOColocationGroups(
75       ColocationGraph* colocation_graph)
76       : colocation_graph_(colocation_graph), next_group_id_(0) {}
77 
AssignGroups(const gtl::InlinedVector<Node *,4> & nodes,std::vector<int> * groups)78   void AssignGroups(const gtl::InlinedVector<Node*, 4>& nodes,
79                     std::vector<int>* groups) {
80     for (int i = 0; i < nodes.size(); ++i) {
81       int root_id = colocation_graph_->FindAndUpdateRoot(nodes[i]->id());
82       const auto& it = group_ids_.find(root_id);
83       int assigned_group_id;
84       if (it == group_ids_.end()) {
85         group_ids_[root_id] = next_group_id_;
86         assigned_group_id = next_group_id_;
87         ++next_group_id_;
88       } else {
89         assigned_group_id = it->second;
90       }
91       groups->push_back(assigned_group_id);
92     }
93   }
94 
FillGroups(std::vector<PossibleDevices> * group_devices)95   Status FillGroups(std::vector<PossibleDevices>* group_devices) {
96     group_devices->resize(group_ids_.size());
97     for (const auto& it : group_ids_) {
98       int assigned_group_id = it.second;
99       PossibleDevices& possible_devices = (*group_devices)[assigned_group_id];
100       const Member& member = colocation_graph_->members()[it.first];
101       TF_RETURN_IF_ERROR(member.FillPossibleDevices(&possible_devices));
102     }
103     return Status::OK();
104   }
105 
106  private:
107   ColocationGraph* colocation_graph_;
108   // Allocated group ids: collocation_graph root id -> allocated group id.
109   std::unordered_map<int, int> group_ids_;
110   int next_group_id_;
111 };
112 
// Records the placement context that ComputeIOColocationGroups later passes
// on when it builds a ColocationGraph for a called function.
//
// `flib_def` and `device_set` are dereferenced unconditionally here, so both
// must be non-null. `default_device` is stored as the pointer it was given.
// NOTE(review): whether flib_def_ / device_set_ hold references or copies is
// determined by the member declarations in the header - confirm the lifetime
// requirements on the pointees there.
InspectingPlacer::InspectingPlacer(const FunctionStack& stack,
                                   const FunctionLibraryDefinition* flib_def,
                                   const DeviceSet* device_set,
                                   const Device* default_device,
                                   bool allow_soft_placement,
                                   bool log_device_placement)
    : stack_(stack),
      flib_def_(*flib_def),
      device_set_(*device_set),
      default_device_(default_device),
      allow_soft_placement_(allow_soft_placement),
      log_device_placement_(log_device_placement) {}
125 
ComputeIOColocationGroups(const Node & node,IOColocationGroups * groups)126 Status InspectingPlacer::ComputeIOColocationGroups(const Node& node,
127                                                    IOColocationGroups* groups) {
128   const FunctionDef* fdef;
129   NameAttrList func;
130   TF_RETURN_IF_ERROR(GetFunctionDefAndAttrs(flib_def_, node, &fdef, &func));
131   std::unique_ptr<FunctionBody> fbody;
132 
133   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
134                                              &flib_def_, &fbody));
135 
136   TF_RETURN_IF_ERROR(
137       IsolatePlacerInspectionRequiredOps(flib_def_, fbody->graph));
138   if (stack_.HasFunction(func.name())) {
139     return errors::Unimplemented(
140         "Recursive function calls are not supported. Node ",
141         FormatNodeForError(node), " inside the body of ",
142         errors::FormatFunctionForError(stack_.current_function_name()),
143         " calls function ", errors::FormatFunctionForError(func.name()),
144         " which is already present in the call stack:\n  ",
145         stack_.FormatForError());
146   }
147 
148   ColocationGraph colocation_graph(
149       fbody->graph, stack_.Push(&node, func.name()), &flib_def_, &device_set_,
150       default_device_, allow_soft_placement_, log_device_placement_);
151   TF_RETURN_IF_ERROR(colocation_graph.Initialize());
152 
153   ColocationGraphToIOColocationGroups converter(&colocation_graph);
154   converter.AssignGroups(fbody->arg_nodes, &groups->input_groups);
155   converter.AssignGroups(fbody->ret_nodes, &groups->output_groups);
156   TF_RETURN_IF_ERROR(converter.FillGroups(&groups->group_devices));
157   return Status::OK();
158 }
159 
160 }  // namespace tensorflow
161