1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
17 
18 #include "tensorflow/core/framework/op.h"
19 #include "tensorflow/core/framework/tensor_shape.pb.h"
20 #include "tensorflow/core/framework/types.h"
21 #include "tensorflow/core/grappler/graph_view.h"
22 #include "tensorflow/core/grappler/grappler_item.h"
23 #include "tensorflow/core/grappler/op_types.h"
24 #include "tensorflow/core/grappler/utils/symbolic_shapes.h"
25 #include "tensorflow/core/grappler/utils/topological_sort.h"
26 #include "tensorflow/core/grappler/utils/tpu.h"
27 #include "tensorflow/core/lib/core/errors.h"
28 #include "tensorflow/core/lib/strings/str_util.h"
29 #include "tensorflow/core/protobuf/error_codes.pb.h"
30 
31 namespace tensorflow {
32 namespace grappler {
33 namespace internal {
34 
// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
// Maximum number of tensor coefficients (see NumCoefficients) for a tensor to
// still be considered "small" enough to pin its producer to the Host.
constexpr int64 kTensorMaxSize = 64;
38 
39 // All the nodes that should be denylisted and not swapped.
IsDenylisted(const NodeDef & node)40 bool IsDenylisted(const NodeDef& node) {
41   return
42       // Collective ops should not be swapped.
43       IsCollective(node) ||
44       // ControlFlow ops should not be swapped.
45       IsControlFlow(node) ||
46       // NoOp ops should not be swapped (due to group dependencies).
47       IsNoOp(node);
48 }
49 
50 // Check if Tensor is either a string or is integer and small size
IsTensorSmall(const OpInfo::TensorProperties & prop)51 bool IsTensorSmall(const OpInfo::TensorProperties& prop) {
52   if (prop.dtype() == DataType::DT_STRING) {
53     return true;
54   }
55 
56   // Check type to be int32 or int64.
57   if (prop.dtype() != DataType::DT_INT32 &&
58       prop.dtype() != DataType::DT_INT64 &&
59       prop.dtype() != DataType::DT_FLOAT) {
60     return false;
61   }
62 
63   // Check size known and small.
64   const int64 size = NumCoefficients(prop.shape());
65   if (size < 0 || size > kTensorMaxSize) {
66     return false;
67   }
68 
69   return true;
70 }
71 
72 // Find KernelDef for `node`, greedily return first found from `devices`.
TryFindKernelDef(const std::vector<DeviceType> & devices,const NodeDef & node,const KernelDef ** kdef)73 Status TryFindKernelDef(const std::vector<DeviceType>& devices,
74                         const NodeDef& node, const KernelDef** kdef) {
75   for (const DeviceType& device : devices) {
76     const KernelDef* kernel = nullptr;
77     Status s = FindKernelDef(device, node, &kernel, nullptr);
78     if (s.ok()) {
79       if (kdef) {
80         *kdef = kernel;
81       }
82       return Status::OK();
83     }
84   }
85 
86   return errors::NotFound("Could not find KernelDef for op: ", node.op());
87 }
88 
// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
//
// On return, `*is_candidate` is true iff output `port_id` of `node` produces
// a small tensor (see IsTensorSmall) that resides in Host memory: the node is
// placed on CPU, or its kernel pins that output arg to HostMemory. Identity /
// single-input IdentityN nodes are checked recursively through their fanins,
// since they may be optimized away downstream.
// NOTE: the first call may trigger expensive static shape inference through
// `properties`; results are cached for subsequent calls.
Status IsNodeOutputPortHostFriendly(const GraphView& graph,
                                    GraphProperties* properties,
                                    const NodeDef& node, int port_id,
                                    bool* is_candidate) {
  *is_candidate = false;

  // Make sure we are not a denylisted op.
  if (IsDenylisted(node)) {
    return Status::OK();
  }

  // Check to make sure we have the right properties (i.e., statically shaped).
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
  }
  const auto& output_properties = properties->GetOutputProperties(node.name());
  int output_properties_size = output_properties.size();
  // Out-of-range ports are logged but treated as non-candidates rather than
  // errors, so optimization proceeds on the rest of the graph.
  if (port_id >= output_properties_size) {
    LOG(WARNING) << "port_id=" << port_id
                 << " but output_properties.size()=" << output_properties.size()
                 << "\n"
                 << node.DebugString();
    return Status::OK();
  }
  if (!IsTensorSmall(output_properties[port_id])) {
    return Status::OK();
  }

  // These nodes may be optimized away downstream (even if pinned to Host), we
  // should (recursively) check their source.
  if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
    // Candidate only if *every* fanin's output port is Host friendly.
    for (const auto& fanin : graph.GetFanins(node, false)) {
      bool fanin_candidate = false;
      TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
          graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
      if (!fanin_candidate) {
        return Status::OK();
      }
    }
    *is_candidate = true;
    return Status::OK();
  }

  // Check if op's device is on CPU.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return Status::OK();
  }

  // Check if op's output port is pinned to HostMemory.
  const OpDef* op = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
  if (!s.ok()) {
    LOG(WARNING) << "Could not find OpDef for : " << node.op();
    return Status::OK();
  }

  // Map the port_id to output_arg_id (one output arg can cover several ports).
  const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
  if (output_arg_id < 0) {
    LOG(WARNING) << "Invalid port: " << port_id << "!\n"
                 << node.DebugString() << "\n"
                 << op->DebugString();
    return Status::OK();
  }

  // Find the kernel, preferring the node's own device, then GPU, then CPU.
  const KernelDef* kernel = nullptr;
  s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
                       &kernel);
  if (!s.ok()) {
    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return Status::OK();
  }

  // Check if the output_arg is pinned to Host.
  for (const string& host_memory_arg : kernel->host_memory_arg()) {
    if (op->output_arg(output_arg_id).name() == host_memory_arg) {
      *is_candidate = true;
      break;
    }
  }

  return Status::OK();
}
179 
180 // Checks if a node's input port is Host friendly.
181 // Roughly this means checking if the input port is on Host memory.
IsNodeInputPortHostFriendly(const NodeDef & node,int port_id)182 bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
183   // If node is on Host, assume its inputs are Host friendly.
184   if (absl::StrContains(node.device(), DEVICE_CPU)) {
185     return true;
186   }
187 
188   // Check if op's input port is pinned to HostMemory.
189   const OpDef* op = nullptr;
190   Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
191   if (!s.ok()) {
192     LOG(WARNING) << "Could not find OpDef for : " << node.op();
193     return false;
194   }
195   const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);
196 
197   // Find the kernel.
198   const KernelDef* kernel = nullptr;
199   s = internal::TryFindKernelDef(
200       {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
201   if (!s.ok()) {
202     LOG(INFO) << "Could not find KernelDef for: " << node.op();
203     return false;
204   }
205 
206   // Check if the input_arg is pinned to Host.
207   for (const string& host_memory_arg : kernel->host_memory_arg()) {
208     if (op->input_arg(input_arg_id).name() == host_memory_arg) {
209       return true;
210     }
211   }
212 
213   return false;
214 }
215 
// Checks if a node is a candidate to pin to Host.
// The rough algorithm is as follows:
// 1] Check if node is denylisted.
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (atm, friendly means small,
//    ints, and pinned to Host).
//
// On return, `*is_candidate` is true iff the node may be swapped onto the
// Host device. Nodes already placed on CPU are accepted immediately.
// NOTE: the first call may trigger expensive static shape inference through
// `properties`; results are cached for subsequent calls.
Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
                           const NodeDef& node, bool* is_candidate) {
  *is_candidate = false;

  // Check if node already on CPU.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return Status::OK();
  }

  // Skip these node types.
  if (IsDenylisted(node)) {
    return Status::OK();
  }

  // Check the node can be run on CPU (i.e., a CPU kernel is registered).
  Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
  if (!s.ok()) {
    return Status::OK();
  }

  // Check all inputs are Host friendly.
  for (const GraphView::OutputPort& fanin :
       graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
    bool fanin_candidate = false;
    TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
        graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
    if (!fanin_candidate) {
      return Status::OK();
    }
  }

  // Check all outputs are Host friendly.
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
  }
  for (const auto& prop : properties->GetOutputProperties(node.name())) {
    if (!IsTensorSmall(prop)) {
      return Status::OK();
    }
  }

  *is_candidate = true;
  return Status::OK();
}
270 
271 // Tries to find a Host device from `devices`. Returns empty string if no
272 // matching Host device is found.
TryFindHostDevice(const gtl::FlatSet<string> & devices,bool has_device_cpu,const string & device)273 string TryFindHostDevice(const gtl::FlatSet<string>& devices,
274                          bool has_device_cpu, const string& device) {
275   // Force this node onto the CPU.
276   if (device.empty() && has_device_cpu) {
277     return "/device:CPU:0";
278   } else if (absl::StrContains(device, DEVICE_GPU)) {
279     // Sometimes the cluster can have:
280     //   devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
281     // and we need to handle them properly.
282     for (const auto& device_match :
283          {std::pair<string, string>("GPU", "CPU:0"),
284           std::pair<string, string>("/device", "/device:CPU:0")}) {
285       const string device_host =
286           strings::StrCat(device.substr(0, device.rfind(device_match.first)),
287                           device_match.second);
288       if (devices.find(device_host) != devices.end()) {
289         return device_host;
290       }
291     }
292   }
293 
294   // We couldn't find an appropriate Host device, return no device.
295   return "";
296 }
297 }  // end namespace internal
298 
// Pins eligible nodes (small int/string tensors, Host-friendly fanins and
// fanouts) to a Host device, then greedily swaps Const nodes back to their
// original device when any consumer cannot take the value from Host memory.
// TPU graphs are left untouched. `cluster` may be null, in which case only
// "/device:CPU:0" is assumed to exist.
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  // Skip all TPU graphs.
  if (IsTPUGraphDef(*optimized_graph)) {
    return Status::OK();
  }

  GraphProperties properties(item);
  GraphView graph(optimized_graph);

  // Collect the set of device names available to place nodes on.
  gtl::FlatSet<string> devices;
  if (cluster) {
    const std::vector<string> device_names = cluster->GetDeviceNames();
    devices.insert(device_names.begin(), device_names.end());
  } else {
    devices = {"/device:CPU:0"};
  }

  const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();

  // Topologically sort the graph, so that we traverse the nodes in order. This
  // will help us discover producer->consumer chains of Host ops.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));

  // All the Const nodes, and their original devices in topological order.
  std::vector<std::pair<NodeDef*, string>> const_nodes;

  // Phase 1: move every Host candidate onto a Host device.
  for (auto& node : *optimized_graph->mutable_node()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    bool is_candidate = false;
    TF_RETURN_IF_ERROR(
        internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
    if (!is_candidate) {
      continue;
    }

    string device =
        internal::TryFindHostDevice(devices, has_device_cpu, node.device());
    if (!device.empty()) {
      // Keep track of all Const nodes that we swapped.
      if (IsConstant(node)) {
        const_nodes.emplace_back(&node, node.device());
      }
      VLOG(2) << "Moving node " << node.name() << " to device " << device;
      *node.mutable_device() = std::move(device);
    }
  }

  // Phase 2: traverse all `const_nodes`, and map them back to GPU greedily.
  for (auto& it : const_nodes) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = it.first;
    const string& device = it.second;

    // Check all the consumers of this node, if any of them are not on CPU, swap
    // this node back onto the original device.
    for (const GraphView::InputPort& fanout : graph.GetFanouts(*node, false)) {
      // The consumer is not Host friendly, swap it back to the original device.
      if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
                                                 fanout.port_id)) {
        VLOG(2) << "Swapping node " << node->name() << " back to device "
                << device;
        node->set_device(device);
        break;
      }
    }
  }
  return Status::OK();
}
370 
371 }  // end namespace grappler
372 }  // end namespace tensorflow
373