1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/jit/partially_decluster_pass.h"
17
18 #include "absl/algorithm/container.h"
19 #include "absl/container/flat_hash_set.h"
20 #include "absl/strings/str_cat.h"
21 #include "tensorflow/compiler/jit/device_util.h"
22 #include "tensorflow/compiler/jit/xla_cluster_util.h"
23 #include "tensorflow/compiler/tf2xla/const_analysis.h"
24 #include "tensorflow/compiler/tf2xla/xla_op_registry.h"
25 #include "tensorflow/core/common_runtime/function.h"
26 #include "tensorflow/core/framework/function.h"
27 #include "tensorflow/core/framework/memory_types.h"
28 #include "tensorflow/core/framework/node_def.pb.h"
29 #include "tensorflow/core/framework/op_kernel.h"
30 #include "tensorflow/core/graph/graph_node_util.h"
31 #include "tensorflow/core/lib/core/errors.h"
32 #include "tensorflow/core/public/version.h"
33
34 namespace tensorflow {
35 namespace {
36
NotBackedge(const Edge & edge)37 bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }
38
39 namespace reduce_device_to_host_copies {
FindNodesToDecluster(const Graph & graph,absl::flat_hash_set<Node * > * result,absl::Span<Node * const> post_order)40 Status FindNodesToDecluster(const Graph& graph,
41 absl::flat_hash_set<Node*>* result,
42 absl::Span<Node* const> post_order) {
43 // Find nodes that have at least one user outside their cluster that expects
44 // hostmem output. These nodes should be cloned to outside the cluster to
45 // avoid the device-host copy we'd otherwise need.
46
47 MemoryTypeVector input_mtypes, output_mtypes;
48
49 for (Node* n : post_order) {
50 absl::optional<absl::string_view> from_cluster = GetXlaClusterForNode(*n);
51 if (!from_cluster) {
52 continue;
53 }
54
55 // Assume the benefit of not outputting a larger tensor outweighs the
56 // benefit of this check.
57 // TODO(tpopp): Only apply this if the value being consumed is not output
58 // from the cluster to another consumer.
59 // TODO(tpopp): See if XlaRun can be modified to avoid this issue
60 // completely.
61 if (IsShapeConsumerOp(*n)) {
62 continue;
63 }
64 // We assume the only XLA-auto-clusterable operations with side effects are
65 // resource variable updates. We can't execute these twice.
66 if (HasResourceInputOrOutput(*n)) {
67 continue;
68 }
69
70 DeviceType device_type("");
71 TF_RETURN_IF_ERROR(
72 DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
73 TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
74 n->def(), &input_mtypes,
75 &output_mtypes));
76 for (const Edge* e : n->out_edges()) {
77 Node* dst = e->dst();
78
79 if (e->IsControlEdge()) {
80 continue;
81 }
82
83 bool edge_incurs_extra_device_to_host_copy;
84 if (output_mtypes[e->src_output()] == DEVICE_MEMORY) {
85 // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then
86 // keep the node clustered -- XLA will also produce the output in device
87 // memory and we will get some benefit from clustering.
88 edge_incurs_extra_device_to_host_copy = false;
89 } else {
90 MemoryTypeVector dst_input_mtypes, dst_output_mtypes;
91 DeviceType dst_device_type("");
92 TF_RETURN_IF_ERROR(DeviceNameToDeviceType(dst->assigned_device_name(),
93 &dst_device_type));
94 TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
95 dst->def(), &dst_input_mtypes,
96 &dst_output_mtypes));
97 edge_incurs_extra_device_to_host_copy =
98 dst_input_mtypes[e->dst_input()] == HOST_MEMORY;
99 }
100
101 if (!edge_incurs_extra_device_to_host_copy) {
102 continue;
103 }
104
105 // Check if `dst` is in a different cluster, unclustered, or about to be
106 // partially declustered (here we rely on the post-order traversal order).
107 // If yes, decluster `n` to avoid the device-to-host memcpy.
108 absl::optional<absl::string_view> dst_cluster =
109 result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst);
110 if (from_cluster != dst_cluster) {
111 CHECK(result->insert(n).second);
112 break;
113 }
114 }
115 }
116 return Status::OK();
117 }
118
PartiallyDeclusterNode(Graph * graph,Node * n)119 Status PartiallyDeclusterNode(Graph* graph, Node* n) {
120 absl::string_view cluster_name = *GetXlaClusterForNode(*n);
121 absl::InlinedVector<const Edge*, 6> out_edges_to_clone;
122 for (const Edge* out_edge : n->out_edges()) {
123 if (out_edge->IsControlEdge()) {
124 continue;
125 }
126
127 Node* dst = out_edge->dst();
128 absl::optional<absl::string_view> dst_cluster_name =
129 GetXlaClusterForNode(*dst);
130 if (dst_cluster_name != cluster_name) {
131 out_edges_to_clone.push_back(out_edge);
132 }
133 }
134
135 CHECK(!out_edges_to_clone.empty()) << n->DebugString();
136
137 NodeDef ndef = n->def();
138 ndef.set_name(absl::StrCat(n->name(), "/declustered"));
139 MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
140 RemoveFromXlaCluster(&ndef);
141 Status s;
142 Node* cloned_node = graph->AddNode(ndef, &s);
143 cloned_node->set_assigned_device_name(n->assigned_device_name());
144 TF_RETURN_IF_ERROR(s);
145
146 for (const Edge* in_edge : n->in_edges()) {
147 graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node,
148 in_edge->dst_input());
149 }
150
151 for (const Edge* out_edge_to_clone : out_edges_to_clone) {
152 graph->AddEdge(cloned_node, out_edge_to_clone->src_output(),
153 out_edge_to_clone->dst(), out_edge_to_clone->dst_input());
154 graph->RemoveEdge(out_edge_to_clone);
155 }
156
157 if (n->out_edges().empty()) {
158 graph->RemoveNode(n);
159 }
160
161 return Status::OK();
162 }
163
164 // Clones nodes to outside their cluster to avoid device-to-host copies. For
165 // instance, converts this:
166 //
167 // .....
168 // |
169 // v
170 // A_Clustered ====> C_Unclustered
171 // |
172 // v
173 // B_Clustered
174 //
175 // to:
176 //
177 // .....
178 // | |
179 // | +-------------+
180 // | |
181 // v v
182 // A_Clustered A_Unclustered ====> C_Unclustered
183 // |
184 // v
185 // B_Clustered
186 //
187 // where the ===> arrow has a hostmem source and destination and would entail a
188 // device to host copy if the source and destination were not in the same XLA
189 // cluster.
PartiallyDeclusterGraph(Graph * graph)190 Status PartiallyDeclusterGraph(Graph* graph) {
191 // When deciding whether to decluster a particular node, we base our decision
192 // on if we've decided that some of its consumers have to be declustered too.
193 // Iterating the graph in post-order guarantees that consumers have been
194 // visited before producers.
195 std::vector<Node*> post_order;
196 GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
197 /*edge_filter=*/NotBackedge);
198
199 absl::flat_hash_set<Node*> nodes_to_partially_decluster;
200 TF_RETURN_IF_ERROR(
201 FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
202
203 if (VLOG_IS_ON(3)) {
204 for (Node* n : post_order) {
205 if (nodes_to_partially_decluster.count(n)) {
206 VLOG(3) << n->DebugString();
207 }
208 }
209 }
210
211 for (Node* n : post_order) {
212 if (nodes_to_partially_decluster.count(n)) {
213 TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n));
214 }
215 }
216
217 // Recompute post order since PartiallyDeclusterNode may have deleted nodes.
218 post_order.clear();
219 GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
220 /*edge_filter=*/NotBackedge);
221 nodes_to_partially_decluster.clear();
222 TF_RETURN_IF_ERROR(
223 FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
224 CHECK(nodes_to_partially_decluster.empty());
225
226 return Status::OK();
227 }
228 } // namespace reduce_device_to_host_copies
229
230 namespace reduce_recompilation {
IsIntraClusterEdge(const Edge & edge)231 bool IsIntraClusterEdge(const Edge& edge) {
232 absl::optional<absl::string_view> src_cluster_name =
233 GetXlaClusterForNode(*edge.src());
234 absl::optional<absl::string_view> dst_cluster_name =
235 GetXlaClusterForNode(*edge.dst());
236 return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
237 }
238
IsMustCompileDevice(const DeviceType & device_type)239 bool IsMustCompileDevice(const DeviceType& device_type) {
240 const XlaOpRegistry::DeviceRegistration* registration;
241 if (XlaOpRegistry::GetCompilationDevice(device_type.type(), ®istration)) {
242 return registration->autoclustering_policy ==
243 XlaOpRegistry::AutoclusteringPolicy::kAlways;
244 }
245
246 return false;
247 }
248
MustCompileNode(const Node * n,bool * must_compile)249 Status MustCompileNode(const Node* n, bool* must_compile) {
250 DeviceType device_type("");
251 TF_RETURN_IF_ERROR(
252 DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
253
254 if (IsMustCompileDevice(device_type)) {
255 *must_compile = true;
256 return Status::OK();
257 }
258
259 // We must compile `n` if it does not have a TensorFlow kernel.
260 *must_compile = !FindKernelDef(device_type, n->def(), nullptr, nullptr).ok();
261 return Status::OK();
262 }
263
264 // Declusters nodes to reduce the number of times we think we need to recompile
265 // a TensorFlow graph.
266 //
267 // Abstractly, if we have a cluster of this form:
268 //
269 // x0 = arg0
270 // x1 = arg1
271 // ...
272 // shape = f(x0, x1, ...)
273 // result = Reshape(input=<something>, new_shape=shape)
274 //
275 // then pulling `f` out of the cluster may reduce the number of compilations and
276 // will never increase the number of compilations.
277 //
278 // We may reduce the number of compilations if f is many to one. For instance
279 // if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
280 // compilations if f is in the cluster but only one compilation if f is outside
281 // the cluster.
282 //
283 // Declustering f will increase the number of compilations only if f is a
284 // one-to-many "function" i.e. isn't a function at all. RNG is one possible
285 // example, depending on how we look at it. But we never create clusters where
286 // such f's would be marked as must-be-constant.
287 //
288 // We assume here that the extra repeated (repeated compared to a clustered f
289 // where it will always be constant folded) host-side computation of f does not
290 // regress performance in any significant manner. We will have to revisit this
291 // algorithm with a more complex cost model if this assumption turns out to be
292 // incorrect.
PartiallyDeclusterGraph(Graph * graph,const FunctionLibraryDefinition * flib_def,Env * env)293 Status PartiallyDeclusterGraph(Graph* graph,
294 const FunctionLibraryDefinition* flib_def,
295 Env* env) {
296 std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
297 OptimizerOptions opts;
298 auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
299 nullptr, env, /*config=*/nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts);
300 FunctionLibraryRuntime* lib_runtime =
301 pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
302 TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr,
303 &compile_time_const_nodes,
304 lib_runtime, IsIntraClusterEdge));
305
306 std::vector<Node*> rpo;
307 GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
308 /*edge_filter=*/NotBackedge);
309 for (Node* n : rpo) {
310 if (!compile_time_const_nodes[n->id()]) {
311 continue;
312 }
313
314 absl::string_view cluster_name = *GetXlaClusterForNode(*n);
315 bool node_on_cluster_edge =
316 absl::c_all_of(n->in_edges(), [&](const Edge* e) {
317 absl::optional<absl::string_view> incoming_cluster =
318 GetXlaClusterForNode(*e->src());
319 return !incoming_cluster || *incoming_cluster != cluster_name;
320 });
321
322 // We don't want to decluster F in a graph like
323 //
324 // Input -> OP -> Shape -> F -> Reshape
325 //
326 // Doing so will break up the cluster. Even if we were okay with breaking
327 // up the cluster we will at least have to relabel the two clusters to have
328 // different cluster names.
329 //
330 // We may want to revisit this in the future: we may have cases where OP is
331 // a small computation that does not benefit from XLA while XLA can optimize
332 // everything that follows the Reshape. In these cases it may be wise to
333 // remove Input, OP, Shape and F from the cluster, if F is a many-to-one
334 // function.
335 //
336 // Note that we do do the right thing for graphs like:
337 //
338 // Input -> F0 -> F1 -> Reshape
339 //
340 // Since we iterate in RPO, we'll first encounter F0, decluster it, then
341 // encounter F1, decluster it and so on.
342 if (node_on_cluster_edge) {
343 bool must_compile_node;
344 TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
345 if (!must_compile_node) {
346 VLOG(3) << "Declustering must-be-constant node " << n->name();
347 RemoveFromXlaCluster(n);
348 }
349 }
350 }
351
352 return Status::OK();
353 }
354 } // namespace reduce_recompilation
355
356 namespace decluster_root_shape_consumers {
357
PartiallyDeclusterGraph(Graph * graph)358 Status PartiallyDeclusterGraph(Graph* graph) {
359 std::vector<Node*> reverse_post_order;
360 GetReversePostOrder(*graph, &reverse_post_order,
361 /*stable_comparator=*/NodeComparatorName(),
362 /*edge_filter=*/NotBackedge);
363
364 for (Node* n : reverse_post_order) {
365 if (!IsShapeConsumerOp(*n)) {
366 continue;
367 }
368
369 absl::optional<absl::string_view> cluster = GetXlaClusterForNode(*n);
370 if (!cluster.has_value()) {
371 continue;
372 }
373
374 auto input_belongs_to_same_cluster = [&](const Edge* e) {
375 return cluster == GetXlaClusterForNode(*e->src());
376 };
377
378 if (absl::c_any_of(n->in_edges(), input_belongs_to_same_cluster)) {
379 continue;
380 }
381
382 VLOG(2) << "Declustering " << n->name()
383 << " because it is a root shape consumer";
384 RemoveFromXlaCluster(n);
385 }
386 return Status::OK();
387 }
388 } // namespace decluster_root_shape_consumers
389 } // namespace
390
Run(const GraphOptimizationPassOptions & options)391 Status PartiallyDeclusterPass::Run(
392 const GraphOptimizationPassOptions& options) {
393 // NB! In this pass we assume the only XLA-auto-clusterable operations that
394 // may have side effects are resource variable operations so we don't cluster
395 // those. The pass will have to be updated if this assumption becomes
396 // invalid.
397
398 Graph* graph = options.graph->get();
399
400 TF_RETURN_IF_ERROR(
401 reduce_device_to_host_copies::PartiallyDeclusterGraph(graph));
402 if (options.flib_def == nullptr) {
403 return errors::InvalidArgument(
404 "GraphOptimizationPassOptions::flib_def must be set for "
405 "PartiallyDeclusterPass.");
406 }
407 if (options.session_options == nullptr ||
408 options.session_options->env == nullptr) {
409 return errors::InvalidArgument(
410 "GraphOptimizationPassOptions::session_options::env must be set for "
411 "PartiallyDeclusterPass.");
412 }
413 TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(
414 graph, options.flib_def, options.session_options->env));
415
416 TF_RETURN_IF_ERROR(
417 decluster_root_shape_consumers::PartiallyDeclusterGraph(graph));
418
419 return Status::OK();
420 }
421 } // namespace tensorflow
422