/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"

#include "absl/strings/match.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/graph_view.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/op_types.h"
#include "tensorflow/core/grappler/utils/symbolic_shapes.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/tpu.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/protobuf/error_codes.pb.h"

namespace tensorflow {
namespace grappler {
namespace internal {

// TODO(williamchan): Change this constant to be something smarter, maybe
// dynamically determined.
constexpr int64 kTensorMaxSize = 64;
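// Note: `kTensorMaxSize` is an element count (it is compared against
// NumCoefficients() below), not a size in bytes.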

// All the nodes that should be denylisted and not swapped.
bool IsDenylisted(const NodeDef& node) {
  return
      // Collective ops should not be swapped.
      IsCollective(node) ||
      // ControlFlow ops should not be swapped.
      IsControlFlow(node) ||
      // NoOp ops should not be swapped (due to group dependencies).
      IsNoOp(node);
}

// Checks if the tensor is small: either a string, or an int32/int64/float
// tensor whose size is known and small.
bool IsTensorSmall(const OpInfo::TensorProperties& prop) {
  if (prop.dtype() == DataType::DT_STRING) {
    return true;
  }

  // Check the type is int32, int64, or float.
  if (prop.dtype() != DataType::DT_INT32 &&
      prop.dtype() != DataType::DT_INT64 &&
      prop.dtype() != DataType::DT_FLOAT) {
    return false;
  }

  // Check the size is known and small.
  const int64 size = NumCoefficients(prop.shape());
  if (size < 0 || size > kTensorMaxSize) {
    return false;
  }

  return true;
}
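// For example, a DT_INT32 tensor of shape [4] is small (4 elements <=
// kTensorMaxSize), a DT_INT32 tensor of shape [128] is not, and neither is
// any tensor whose shape is not fully known (NumCoefficients() is negative).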

// Finds the KernelDef for `node`, greedily returning the first one found
// across `devices`.
Status TryFindKernelDef(const std::vector<DeviceType>& devices,
                        const NodeDef& node, const KernelDef** kdef) {
  for (const DeviceType& device : devices) {
    const KernelDef* kernel = nullptr;
    Status s = FindKernelDef(device, node, &kernel, nullptr);
    if (s.ok()) {
      if (kdef) {
        *kdef = kernel;
      }
      return Status::OK();
    }
  }

  return errors::NotFound("Could not find KernelDef for op: ", node.op());
}
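// For example, TryFindKernelDef({DEVICE_CPU}, node, nullptr) merely checks
// that a CPU kernel is registered for `node`, without inspecting the
// KernelDef itself (see IsNodeHostCandidate below).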

// Checks if a node's output port is host friendly.
// Roughly this means checking if the output port is on Host memory.
Status IsNodeOutputPortHostFriendly(const GraphView& graph,
                                    GraphProperties* properties,
                                    const NodeDef& node, int port_id,
                                    bool* is_candidate) {
  *is_candidate = false;

  // Make sure we are not a denylisted op.
  if (IsDenylisted(node)) {
    return Status::OK();
  }

  // Check to make sure we have the right properties (i.e., statically shaped).
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
  }
  const auto& output_properties = properties->GetOutputProperties(node.name());
  int output_properties_size = output_properties.size();
  if (port_id >= output_properties_size) {
    LOG(WARNING) << "port_id=" << port_id
                 << " but output_properties.size()=" << output_properties.size()
                 << "\n"
                 << node.DebugString();
    return Status::OK();
  }
  if (!IsTensorSmall(output_properties[port_id])) {
    return Status::OK();
  }

  // These nodes may be optimized away downstream (even if pinned to Host), so
  // we should (recursively) check their source.
  if (IsIdentity(node) || IsIdentityNSingleInput(node)) {
    for (const auto& fanin :
         graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
      bool fanin_candidate = false;
      TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
          graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
      if (!fanin_candidate) {
        return Status::OK();
      }
    }
    *is_candidate = true;
    return Status::OK();
  }

  // Check if op's device is on CPU.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return Status::OK();
  }

  // Check if op's output port is pinned to HostMemory.
  const OpDef* op = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
  if (!s.ok()) {
    LOG(WARNING) << "Could not find OpDef for: " << node.op();
    return Status::OK();
  }

  // Map the port_id to output_arg_id.
  const int output_arg_id = OpOutputPortIdToArgId(node, *op, port_id);
  if (output_arg_id < 0) {
    LOG(WARNING) << "Invalid port: " << port_id << "!\n"
                 << node.DebugString() << "\n"
                 << op->DebugString();
    return Status::OK();
  }

  // Find the kernel.
  const KernelDef* kernel = nullptr;
  s = TryFindKernelDef({node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node,
                       &kernel);
  if (!s.ok()) {
    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return Status::OK();
  }

  // Check if the output_arg is pinned to Host.
  for (const string& host_memory_arg : kernel->host_memory_arg()) {
    if (op->output_arg(output_arg_id).name() == host_memory_arg) {
      *is_candidate = true;
      break;
    }
  }

  return Status::OK();
}
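// For example, the int32 output of a Shape op on GPU is Host friendly: the
// GPU kernel for Shape is registered with .HostMemory("output"), so the loop
// above finds "output" among kernel->host_memory_arg(). (Illustrative; the
// exact registration depends on the kernel in question.)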

// Checks if a node's input port is Host friendly.
// Roughly this means checking if the input port is on Host memory.
bool IsNodeInputPortHostFriendly(const NodeDef& node, int port_id) {
  // If node is on Host, assume its inputs are Host friendly.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    return true;
  }

  // Check if op's input port is pinned to HostMemory.
  const OpDef* op = nullptr;
  Status s = OpRegistry::Global()->LookUpOpDef(node.op(), &op);
  if (!s.ok()) {
    LOG(WARNING) << "Could not find OpDef for: " << node.op();
    return false;
  }
  const int input_arg_id = OpInputPortIdToArgId(node, *op, port_id);

  // Find the kernel.
  const KernelDef* kernel = nullptr;
  s = internal::TryFindKernelDef(
      {node.device().c_str(), DEVICE_GPU, DEVICE_CPU}, node, &kernel);
  if (!s.ok()) {
    LOG(INFO) << "Could not find KernelDef for: " << node.op();
    return false;
  }

  // Check if the input_arg is pinned to Host.
  for (const string& host_memory_arg : kernel->host_memory_arg()) {
    if (op->input_arg(input_arg_id).name() == host_memory_arg) {
      return true;
    }
  }

  return false;
}

// Checks if a node is a candidate to pin to Host.
// The rough algorithm is as follows:
// 1] Check if node is denylisted.
// 2] Check if node can run on Host.
// 3] Check all input/outputs are Host "friendly" (at the moment, friendly
//    means small, ints, and pinned to Host).
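// For example (an illustrative scenario, not an exhaustive rule): a Shape
// node placed on GPU whose small int32 output feeds shape arithmetic is a
// typical candidate, since Shape also has a CPU kernel and its output is
// small and integer-typed.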
Status IsNodeHostCandidate(const GraphView& graph, GraphProperties* properties,
                           const NodeDef& node, bool* is_candidate) {
  *is_candidate = false;

  // Check if node already on CPU.
  if (absl::StrContains(node.device(), DEVICE_CPU)) {
    *is_candidate = true;
    return Status::OK();
  }

  // Skip these node types.
  if (IsDenylisted(node)) {
    return Status::OK();
  }

  // Check the node can be run on CPU.
  Status s = TryFindKernelDef({DEVICE_CPU}, node, nullptr);
  if (!s.ok()) {
    return Status::OK();
  }

  // Check all inputs are Host friendly.
  for (const GraphView::OutputPort& fanin :
       graph.GetFanins(node, /*include_controlling_nodes=*/false)) {
    bool fanin_candidate = false;
    TF_RETURN_IF_ERROR(IsNodeOutputPortHostFriendly(
        graph, properties, *fanin.node, fanin.port_id, &fanin_candidate));
    if (!fanin_candidate) {
      return Status::OK();
    }
  }

  // Check all outputs are Host friendly.
  if (!properties->has_properties()) {
    // This is an expensive call, call it lazily.
    TF_RETURN_IF_ERROR(properties->InferStatically(
        /*assume_valid_feeds=*/false, /*aggressive_shape_inference=*/false,
        /*include_tensor_values=*/false));
  }
  for (const auto& prop : properties->GetOutputProperties(node.name())) {
    if (!IsTensorSmall(prop)) {
      return Status::OK();
    }
  }

  *is_candidate = true;
  return Status::OK();
}

// Tries to find a Host device from `devices`. Returns an empty string if no
// matching Host device is found.
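// For example, with devices = {"/device:CPU:0", "/device:GPU:0"}, the device
// "/device:GPU:0" maps to "/device:CPU:0" via the ("GPU", "CPU:0") rewrite
// below, and "/device:XLA_GPU:0" maps to "/device:CPU:0" via the
// ("/device", "/device:CPU:0") fallback.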
string TryFindHostDevice(const gtl::FlatSet<string>& devices,
                         bool has_device_cpu, const string& device) {
  // Force this node onto the CPU.
  if (device.empty() && has_device_cpu) {
    return "/device:CPU:0";
  } else if (absl::StrContains(device, DEVICE_GPU)) {
    // Sometimes the cluster can have:
    //   devices = {"/device:CPU:0", "/device:XLA_GPU:0"}
    // and we need to handle them properly.
    for (const auto& device_match :
         {std::pair<string, string>("GPU", "CPU:0"),
          std::pair<string, string>("/device", "/device:CPU:0")}) {
      const string device_host =
          strings::StrCat(device.substr(0, device.rfind(device_match.first)),
                          device_match.second);
      if (devices.find(device_host) != devices.end()) {
        return device_host;
      }
    }
  }

  // We couldn't find an appropriate Host device, return no device.
  return "";
}
}  // end namespace internal

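// A minimal usage sketch (in practice this pass is normally invoked through
// Grappler's meta-optimizer rather than called directly; `item` here is an
// assumed, already-populated GrapplerItem):
//
//   PinToHostOptimizer optimizer;
//   GraphDef output;
//   TF_RETURN_IF_ERROR(optimizer.Optimize(/*cluster=*/nullptr, item, &output));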
Status PinToHostOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                    GraphDef* optimized_graph) {
  *optimized_graph = item.graph;

  // Skip all TPU graphs.
  if (IsTPUGraphDef(*optimized_graph)) {
    return Status::OK();
  }

  GraphProperties properties(item);
  GraphView graph(optimized_graph);

  gtl::FlatSet<string> devices;
  if (cluster) {
    const std::vector<string> device_names = cluster->GetDeviceNames();
    devices.insert(device_names.begin(), device_names.end());
  } else {
    devices = {"/device:CPU:0"};
  }

  const bool has_device_cpu = devices.find("/device:CPU:0") != devices.end();

  // Topologically sort the graph, so that we traverse the nodes in order. This
  // will help us discover producer->consumer chains of Host ops.
  TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));

  // All the Const nodes, and their original devices in topological order.
  std::vector<std::pair<NodeDef*, string>> const_nodes;

  for (auto& node : *optimized_graph->mutable_node()) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    bool is_candidate = false;
    TF_RETURN_IF_ERROR(
        internal::IsNodeHostCandidate(graph, &properties, node, &is_candidate));
    if (!is_candidate) {
      continue;
    }

    string device =
        internal::TryFindHostDevice(devices, has_device_cpu, node.device());
    if (!device.empty()) {
      // Keep track of all Const nodes that we swapped.
      if (IsConstant(node)) {
        const_nodes.emplace_back(&node, node.device());
      }
      VLOG(2) << "Moving node " << node.name() << " to device " << device;
      *node.mutable_device() = std::move(device);
    }
  }

  // Traverse all `const_nodes`, and greedily swap them back to their original
  // device (typically GPU) where needed.
  for (auto& it : const_nodes) {
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
    NodeDef* node = it.first;
    const string& device = it.second;

    // Check all the consumers of this node; if any of them consume its output
    // in Device (not Host) memory, swap this node back onto its original
    // device.
    for (const GraphView::InputPort& fanout :
         graph.GetFanouts(*node, /*include_controlled_nodes=*/false)) {
      // The consumer is not Host friendly; swap the node back to its original
      // device.
      if (!internal::IsNodeInputPortHostFriendly(*fanout.node,
                                                 fanout.port_id)) {
        VLOG(2) << "Swapping node " << node->name() << " back to device "
                << device;
        node->set_device(device);
        break;
      }
    }
  }
  return Status::OK();
}

}  // end namespace grappler
}  // end namespace tensorflow