/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/jit/partially_decluster_pass.h"

#include "absl/algorithm/container.h"
#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/jit/device_util.h"
#include "tensorflow/compiler/jit/xla_cluster_util.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/memory_types.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/graph/graph_node_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/public/version.h"

namespace tensorflow {
namespace {

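// An edge whose source is a NextIteration node is a loop back edge.  The
// traversals below filter these out so the graph can be processed as a DAG.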
bool NotBackedge(const Edge& edge) { return !edge.src()->IsNextIteration(); }

namespace reduce_device_to_host_copies {
Status FindNodesToDecluster(const Graph& graph,
                            absl::flat_hash_set<Node*>* result,
                            absl::Span<Node* const> post_order) {
  // Find nodes that have at least one user outside their cluster that expects
  // hostmem output.  These nodes should be cloned outside the cluster to
  // avoid the device-to-host copy we'd otherwise need.

  MemoryTypeVector input_mtypes, output_mtypes;

  for (Node* n : post_order) {
    absl::optional<absl::string_view> from_cluster = GetXlaClusterForNode(*n);
    if (!from_cluster) {
      continue;
    }

    // Assume the benefit of not outputting a larger tensor outweighs the
    // benefit of this check.
    // TODO(tpopp): Only apply this if the value being consumed is not output
    // from the cluster to another consumer.
    // TODO(tpopp): See if XlaRun can be modified to avoid this issue
    // completely.
    if (IsShapeConsumerOp(*n)) {
      continue;
    }
    // We assume the only XLA-auto-clusterable operations with side effects are
    // resource variable updates.  We can't execute these twice.
    if (HasResourceInputOrOutput(*n)) {
      continue;
    }

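    // Query the TensorFlow kernel registry for whether each output of `n` is
    // produced in host or device memory on its assigned device.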
    DeviceType device_type("");
    TF_RETURN_IF_ERROR(
        DeviceNameToDeviceType(n->assigned_device_name(), &device_type));
    TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(), device_type,
                                          n->def(), &input_mtypes,
                                          &output_mtypes));
    for (const Edge* e : n->out_edges()) {
      Node* dst = e->dst();

      if (e->IsControlEdge()) {
        continue;
      }

      bool edge_incurs_extra_device_to_host_copy;
      if (output_mtypes[e->src_output()] == DEVICE_MEMORY) {
        // If the output of the *TensorFlow* operation is in DEVICE_MEMORY then
        // keep the node clustered -- XLA will also produce the output in
        // device memory and we will get some benefit from clustering.
        edge_incurs_extra_device_to_host_copy = false;
      } else {
        MemoryTypeVector dst_input_mtypes, dst_output_mtypes;
        DeviceType dst_device_type("");
        TF_RETURN_IF_ERROR(DeviceNameToDeviceType(dst->assigned_device_name(),
                                                  &dst_device_type));
        TF_RETURN_IF_ERROR(MemoryTypesForNode(graph.op_registry(),
                                              dst_device_type, dst->def(),
                                              &dst_input_mtypes,
                                              &dst_output_mtypes));
        edge_incurs_extra_device_to_host_copy =
            dst_input_mtypes[e->dst_input()] == HOST_MEMORY;
      }

      if (!edge_incurs_extra_device_to_host_copy) {
        continue;
      }

      // Check if `dst` is in a different cluster, unclustered, or about to be
      // partially declustered (here we rely on the post-order traversal
      // order).  If yes, decluster `n` to avoid the device-to-host memcpy.
      absl::optional<absl::string_view> dst_cluster =
          result->count(dst) ? absl::nullopt : GetXlaClusterForNode(*dst);
      if (from_cluster != dst_cluster) {
        CHECK(result->insert(n).second);
        break;
      }
    }
  }
  return Status::OK();
}

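// Clones `n` to outside its cluster and redirects the data edges that leave
// the cluster to the clone.  If this leaves `n` without any out edges, `n`
// itself is removed from the graph.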
Status PartiallyDeclusterNode(Graph* graph, Node* n) {
  absl::string_view cluster_name = *GetXlaClusterForNode(*n);
  absl::InlinedVector<const Edge*, 6> out_edges_to_clone;
  for (const Edge* out_edge : n->out_edges()) {
    if (out_edge->IsControlEdge()) {
      continue;
    }

    Node* dst = out_edge->dst();
    absl::optional<absl::string_view> dst_cluster_name =
        GetXlaClusterForNode(*dst);
    if (dst_cluster_name != cluster_name) {
      out_edges_to_clone.push_back(out_edge);
    }
  }

  CHECK(!out_edges_to_clone.empty()) << n->DebugString();

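  // The clone is identical to `n` except for a fresh name and no XLA cluster
  // attribute.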
  NodeDef ndef = n->def();
  ndef.set_name(absl::StrCat(n->name(), "/declustered"));
  MergeDebugInfo(NodeDebugInfo(n->def()), &ndef);
  RemoveFromXlaCluster(&ndef);
  Status s;
  Node* cloned_node = graph->AddNode(ndef, &s);
  TF_RETURN_IF_ERROR(s);
  cloned_node->set_assigned_device_name(n->assigned_device_name());

  for (const Edge* in_edge : n->in_edges()) {
    graph->AddEdge(in_edge->src(), in_edge->src_output(), cloned_node,
                   in_edge->dst_input());
  }

  for (const Edge* out_edge_to_clone : out_edges_to_clone) {
    graph->AddEdge(cloned_node, out_edge_to_clone->src_output(),
                   out_edge_to_clone->dst(), out_edge_to_clone->dst_input());
    graph->RemoveEdge(out_edge_to_clone);
  }

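  // If all of `n`'s out edges were moved to the clone, `n` is now dead code.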
  if (n->out_edges().empty()) {
    graph->RemoveNode(n);
  }

  return Status::OK();
}

// Clones nodes outside their cluster to avoid device-to-host copies.  For
// instance, converts this:
//
//         .....
//           |
//           v
//      A_Clustered ====> C_Unclustered
//           |
//           v
//      B_Clustered
//
// to:
//
//         .....
//          | |
//          | +-------------+
//          |               |
//          v               v
//      A_Clustered   A_Unclustered ====> C_Unclustered
//           |
//           v
//      B_Clustered
//
// where the ===> arrow has a hostmem source and destination and would entail
// a device-to-host copy if the source and destination were not in the same
// XLA cluster.
Status PartiallyDeclusterGraph(Graph* graph) {
  // When deciding whether to decluster a particular node, we base our decision
  // on whether we've decided that some of its consumers have to be declustered
  // too.  Iterating the graph in post-order guarantees that consumers have
  // been visited before producers.
  std::vector<Node*> post_order;
  GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
               /*edge_filter=*/NotBackedge);

  absl::flat_hash_set<Node*> nodes_to_partially_decluster;
  TF_RETURN_IF_ERROR(
      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));

  if (VLOG_IS_ON(3)) {
    for (Node* n : post_order) {
      if (nodes_to_partially_decluster.count(n)) {
        VLOG(3) << n->DebugString();
      }
    }
  }

  for (Node* n : post_order) {
    if (nodes_to_partially_decluster.count(n)) {
      TF_RETURN_IF_ERROR(PartiallyDeclusterNode(graph, n));
    }
  }

  // Recompute post order since PartiallyDeclusterNode may have deleted nodes.
  post_order.clear();
  GetPostOrder(*graph, &post_order, /*stable_comparator=*/NodeComparatorName(),
               /*edge_filter=*/NotBackedge);
  nodes_to_partially_decluster.clear();
  TF_RETURN_IF_ERROR(
      FindNodesToDecluster(*graph, &nodes_to_partially_decluster, post_order));
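  // Sanity-check that declustering converged: a second scan must not find
  // anything else to decluster.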
  CHECK(nodes_to_partially_decluster.empty());

  return Status::OK();
}
}  // namespace reduce_device_to_host_copies

namespace reduce_recompilation {
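// Returns true if both endpoints of `edge` are assigned to the same
// (non-empty) XLA cluster.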
bool IsIntraClusterEdge(const Edge& edge) {
  absl::optional<absl::string_view> src_cluster_name =
      GetXlaClusterForNode(*edge.src());
  absl::optional<absl::string_view> dst_cluster_name =
      GetXlaClusterForNode(*edge.dst());
  return src_cluster_name.has_value() && src_cluster_name == dst_cluster_name;
}

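// Returns true if nodes placed on `device_type` are always compiled with XLA
// (autoclustering policy kAlways), in which case we never decluster them.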
bool IsMustCompileDevice(const DeviceType& device_type) {
  const XlaOpRegistry::DeviceRegistration* registration;
  if (XlaOpRegistry::GetCompilationDevice(device_type.type(), &registration)) {
    return registration->autoclustering_policy ==
           XlaOpRegistry::AutoclusteringPolicy::kAlways;
  }

  return false;
}

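// Sets `*must_compile` to true if `n` can only run through XLA: either its
// assigned device always compiles, or `n` has no regular TensorFlow kernel
// for that device.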
Status MustCompileNode(const Node* n, bool* must_compile) {
  DeviceType device_type("");
  TF_RETURN_IF_ERROR(
      DeviceNameToDeviceType(n->assigned_device_name(), &device_type));

  if (IsMustCompileDevice(device_type)) {
    *must_compile = true;
    return Status::OK();
  }

  // We must compile `n` if it does not have a TensorFlow kernel.
  *must_compile = !FindKernelDef(device_type, n->def(), nullptr, nullptr).ok();
  return Status::OK();
}

// Declusters nodes to reduce the number of times we think we need to recompile
// a TensorFlow graph.
//
// Abstractly, if we have a cluster of this form:
//
//   x0 = arg0
//   x1 = arg1
//     ...
//   shape = f(x0, x1, ...)
//   result = Reshape(input=<something>, new_shape=shape)
//
// then pulling `f` out of the cluster may reduce the number of compilations and
// will never increase the number of compilations.
//
// We may reduce the number of compilations if f is many-to-one.  For instance
// if f(x,y) = x-y then x=3,y=1 and x=4,y=2 will generate two different
// compilations if f is in the cluster but only one compilation if f is outside
// the cluster.
//
// Declustering f will increase the number of compilations only if f is a
// one-to-many "function", i.e. isn't a function at all.  RNG is one possible
// example, depending on how we look at it.  But we never create clusters where
// such f's would be marked as must-be-constant.
//
// We assume here that the extra host-side computation of f (extra relative to
// a clustered f, where it would always be constant folded) does not regress
// performance in any significant manner.  We will have to revisit this
// algorithm with a more complex cost model if this assumption turns out to be
// incorrect.
Status PartiallyDeclusterGraph(Graph* graph,
                               const FunctionLibraryDefinition* flib_def,
                               Env* env) {
  std::vector<bool> compile_time_const_nodes(graph->num_node_ids());
  OptimizerOptions opts;
  auto pflr = absl::make_unique<ProcessFunctionLibraryRuntime>(
      nullptr, env, /*config=*/nullptr, TF_GRAPH_DEF_VERSION, flib_def, opts);
  FunctionLibraryRuntime* lib_runtime =
      pflr->GetFLR(ProcessFunctionLibraryRuntime::kDefaultFLRDevice);
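  // Mark the nodes whose values XLA needs as compile-time constants (for
  // example the shape operand of a Reshape inside a cluster), propagating the
  // requirement only across intra-cluster edges.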
  TF_RETURN_IF_ERROR(BackwardsConstAnalysis(*graph, nullptr,
                                            &compile_time_const_nodes,
                                            lib_runtime, IsIntraClusterEdge));

  std::vector<Node*> rpo;
  GetReversePostOrder(*graph, &rpo, /*stable_comparator=*/NodeComparatorName(),
                      /*edge_filter=*/NotBackedge);
  for (Node* n : rpo) {
    if (!compile_time_const_nodes[n->id()]) {
      continue;
    }

    absl::string_view cluster_name = *GetXlaClusterForNode(*n);
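    // `n` is on the cluster's input frontier if none of its inputs come from
    // inside its own cluster.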
    bool node_on_cluster_edge =
        absl::c_all_of(n->in_edges(), [&](const Edge* e) {
          absl::optional<absl::string_view> incoming_cluster =
              GetXlaClusterForNode(*e->src());
          return !incoming_cluster || *incoming_cluster != cluster_name;
        });

    // We don't want to decluster F in a graph like
    //
    //   Input -> OP -> Shape -> F -> Reshape
    //
    // Doing so will break up the cluster.  Even if we were okay with breaking
    // up the cluster we will at least have to relabel the two clusters to have
    // different cluster names.
    //
    // We may want to revisit this in the future: we may have cases where OP
    // is a small computation that does not benefit from XLA while XLA can
    // optimize everything that follows the Reshape.  In these cases it may be
    // wise to remove Input, OP, Shape and F from the cluster, if F is a
    // many-to-one function.
    //
    // Note that we do handle graphs like the following correctly:
    //
    //   Input -> F0 -> F1 -> Reshape
    //
    // Since we iterate in RPO, we'll first encounter F0, decluster it, then
    // encounter F1, decluster it and so on.
    if (node_on_cluster_edge) {
      bool must_compile_node;
      TF_RETURN_IF_ERROR(MustCompileNode(n, &must_compile_node));
      if (!must_compile_node) {
        VLOG(3) << "Declustering must-be-constant node " << n->name();
        RemoveFromXlaCluster(n);
      }
    }
  }

  return Status::OK();
}
}  // namespace reduce_recompilation

namespace decluster_root_shape_consumers {

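// Removes from their clusters the shape-consuming ops (see IsShapeConsumerOp)
// none of whose inputs come from inside the same cluster.  Such "root" shape
// consumers only inspect metadata of values produced outside the cluster, so
// compiling them with the cluster is unlikely to help.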
Status PartiallyDeclusterGraph(Graph* graph) {
  std::vector<Node*> reverse_post_order;
  GetReversePostOrder(*graph, &reverse_post_order,
                      /*stable_comparator=*/NodeComparatorName(),
                      /*edge_filter=*/NotBackedge);

  for (Node* n : reverse_post_order) {
    if (!IsShapeConsumerOp(*n)) {
      continue;
    }

    absl::optional<absl::string_view> cluster = GetXlaClusterForNode(*n);
    if (!cluster.has_value()) {
      continue;
    }

    auto input_belongs_to_same_cluster = [&](const Edge* e) {
      return cluster == GetXlaClusterForNode(*e->src());
    };

    if (absl::c_any_of(n->in_edges(), input_belongs_to_same_cluster)) {
      continue;
    }

    VLOG(2) << "Declustering " << n->name()
            << " because it is a root shape consumer";
    RemoveFromXlaCluster(n);
  }
  return Status::OK();
}
}  // namespace decluster_root_shape_consumers
}  // namespace

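// Runs the three stages above in order: clone nodes whose outputs would
// otherwise incur device-to-host copies, decluster must-be-constant nodes on
// cluster edges to reduce recompilation, and decluster root shape consumers.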
Status PartiallyDeclusterPass::Run(
    const GraphOptimizationPassOptions& options) {
  // NB!  In this pass we assume the only XLA-auto-clusterable operations that
  // may have side effects are resource variable operations so we don't cluster
  // those.  The pass will have to be updated if this assumption becomes
  // invalid.

  Graph* graph = options.graph->get();

  TF_RETURN_IF_ERROR(
      reduce_device_to_host_copies::PartiallyDeclusterGraph(graph));
  if (options.flib_def == nullptr) {
    return errors::InvalidArgument(
        "GraphOptimizationPassOptions::flib_def must be set for "
        "PartiallyDeclusterPass.");
  }
  if (options.session_options == nullptr ||
      options.session_options->env == nullptr) {
    return errors::InvalidArgument(
        "GraphOptimizationPassOptions::session_options::env must be set for "
        "PartiallyDeclusterPass.");
  }
  TF_RETURN_IF_ERROR(reduce_recompilation::PartiallyDeclusterGraph(
      graph, options.flib_def, options.session_options->env));

  TF_RETURN_IF_ERROR(
      decluster_root_shape_consumers::PartiallyDeclusterGraph(graph));

  return Status::OK();
}
}  // namespace tensorflow