1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/jit/encapsulate_util.h"
17 
18 #include <algorithm>
19 #include <iterator>
20 
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/strings/str_cat.h"
24 #include "absl/types/optional.h"
25 #include "tensorflow/compiler/jit/shape_inference.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
27 #include "tensorflow/core/framework/node_def_util.h"
28 #include "tensorflow/core/graph/node_builder.h"
29 #include "tensorflow/core/protobuf/error_codes.pb.h"
30 #include "tensorflow/stream_executor/lib/statusor.h"
31 
32 using stream_executor::port::StatusOr;
33 
34 namespace tensorflow {
35 
36 namespace {
37 
38 // Returns string attribute value for the node if the attribute is present,
39 // otherwise returns empty optional value.
GetStringAttr(const Node & n,const string & attr_name)40 absl::optional<string> GetStringAttr(const Node& n, const string& attr_name) {
41   auto attr = n.attrs().Find(attr_name);
42   if (!attr) {
43     return absl::nullopt;
44   } else {
45     return attr->s();
46   }
47 }
48 
49 // Adds a value to the node's list attribute.
50 template <typename T>
AppendToListAttr(Node * n,const string & attr_name,const string & value)51 Status AppendToListAttr(Node* n, const string& attr_name, const string& value) {
52   std::vector<T> attr_value;
53   Status s = GetNodeAttr(n->attrs(), attr_name, &attr_value);
54   if (!s.ok() && s.code() != error::NOT_FOUND) {
55     return s;
56   }
57 
58   n->ClearAttr(attr_name);
59   attr_value.push_back(value);
60   n->AddAttr(attr_name, attr_value);
61   return Status::OK();
62 }
63 
64 // Replaces attribute value.
65 template <typename T>
ReplaceAttr(Node * n,const string & attr_name,const T & value)66 void ReplaceAttr(Node* n, const string& attr_name, const T& value) {
67   n->ClearAttr(attr_name);
68   n->AddAttr(attr_name, value);
69 }
70 
71 // Step 1 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
72 // `PreprocessEdgesBetweenOutsideCompilations` for details.
PreprocessControlEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)73 Status PreprocessControlEdgesBetweenOutsideCompilations(
74     Graph* g, const string& outside_compilation_attr_name) {
75   // Gather edges to remove. We should not remove the edge while iterating.
76   std::vector<const Edge*> edges_to_remove;
77   for (const Edge* e : g->edges()) {
78     if (!e->IsControlEdge()) {
79       continue;
80     }
81 
82     auto src_outside_compilation =
83         GetStringAttr(*e->src(), outside_compilation_attr_name);
84     auto dst_outside_compilation =
85         GetStringAttr(*e->dst(), outside_compilation_attr_name);
86 
87     if (src_outside_compilation && dst_outside_compilation) {
88       if (*src_outside_compilation != *dst_outside_compilation) {
89         // Case 1a: outside compilation to outside compilation control edge.
90         edges_to_remove.push_back(e);
91 
92         TF_RETURN_IF_ERROR(AppendToListAttr<string>(
93             e->dst(), kXlaControlDependenciesWithinXlaClusterAttrName,
94             e->src()->name()));
95       }
96     } else if (src_outside_compilation && !dst_outside_compilation) {
97       // Case 1b: outside compilation to its XLA computation control edge.
98       ReplaceAttr(e->src(), kXlaConnectedToXlaComputationAttrName, true);
99     } else if (!src_outside_compilation && dst_outside_compilation) {
100       // Case 1b: XLA computation to outside compilation in it control edge.
101       ReplaceAttr(e->dst(), kXlaConnectedFromXlaComputationAttrName, true);
102     }
103   }
104 
105   for (auto e : edges_to_remove) {
106     g->RemoveEdge(e);
107   }
108   return Status::OK();
109 }
110 
// Step 2 for `PreprocessEdgesBetweenOutsideCompilations`. See comments of
// `PreprocessEdgesBetweenOutsideCompilations` for details.
//
// For every data edge whose src and dst both carry
// `outside_compilation_attr_name` with different values:
//  - the edge is removed;
//  - the dst input is fed from a "Placeholder" node instead.
// The placeholder records three things:
//  - the original src name (kOutsideCompilationOriginalNodeAttrName);
//  - the src output index (kOutsideCompilationSrcOutputAttrName);
//  - the outside compilation attribute value of the dst that triggered its
//    creation.
// Placeholders are cached per (src name, src output) and reused by later
// edges from the same output. Each rewired dst is rebuilt via ReplaceNode so
// that its NodeDef input string names the placeholder.
Status PreprocessDataEdgesBetweenOutsideCompilations(
    Graph* g, const string& outside_compilation_attr_name) {
  // Gather data edges between different outside compilation clusters. We
  // store (dst input, dst node id) rather than `Edge*` because dst nodes are
  // replaced below via ReplaceNode, and Edge pointers into replaced nodes
  // might be invalidated.
  struct EdgeInfo {
    int dst_input, dst_node_id;
  };
  std::vector<EdgeInfo> edges;
  for (const Edge* e : g->edges()) {
    if (e->IsControlEdge()) {
      continue;
    }

    auto src_outside_compilation =
        GetStringAttr(*e->src(), outside_compilation_attr_name);
    auto dst_outside_compilation =
        GetStringAttr(*e->dst(), outside_compilation_attr_name);

    if (src_outside_compilation && dst_outside_compilation &&
        *src_outside_compilation != *dst_outside_compilation) {
      edges.push_back(EdgeInfo{e->dst_input(), e->dst()->id()});
      VLOG(4) << "Oc -> oc edge: " << e->DebugString();
    }
  }

  // Remove each oc -> oc data edge and feed the dst input from a placeholder
  // instead. Keyed by (src node name, src output).
  std::map<std::pair<string, int>, Node*> placeholders;
  for (int i = 0, end = edges.size(); i < end; i++) {
    // `dst_node_id` may have been rewritten by an earlier iteration to point
    // at the replacement node created for the same dst.
    Node* dst = g->FindNodeId(edges[i].dst_node_id);
    const Edge* e;
    TF_RETURN_IF_ERROR(dst->input_edge(edges[i].dst_input, &e));
    // The src is looked up from the current graph, so it is never stale even
    // if the src was itself replaced earlier. Everything needed from `e` is
    // copied out before the edge is removed.
    Node* src = e->src();
    int src_output = e->src_output(), dst_input = e->dst_input();
    g->RemoveEdge(e);

    // Find or create placeholder node.
    // NOTE(review): the cache key is only (src name, src output), but a new
    // placeholder takes the outside compilation attribute of the *first* dst
    // that needed it. If one src output feeds dsts in two different outside
    // compilation clusters, the second dst reuses a placeholder tagged with
    // the first dst's cluster. Confirm whether that is intended.
    string new_name =
        absl::StrCat(src->name(), "_oc_to_oc_placeholder_", src_output);
    auto placeholder_index = std::make_pair(src->name(), src_output);
    auto iter = placeholders.find(placeholder_index);
    Node* placeholder_node;
    if (iter == placeholders.end()) {
      NodeDefBuilder placeholder_builder(new_name, "Placeholder");
      placeholder_builder.Attr("dtype", src->output_type(src_output));
      string outside_compilation_attr;
      TF_RETURN_IF_ERROR(GetNodeAttr(dst->attrs(),
                                     outside_compilation_attr_name,
                                     &outside_compilation_attr));
      placeholder_builder.Attr(outside_compilation_attr_name,
                               outside_compilation_attr);
      // These two attributes let
      // PostprocessDataEdgesBetweenOutsideCompilations find the original src
      // output again.
      placeholder_builder.Attr(kOutsideCompilationOriginalNodeAttrName,
                               src->name());
      placeholder_builder.Attr(kOutsideCompilationSrcOutputAttrName,
                               src_output);
      NodeDef placeholder_def;
      TF_RETURN_IF_ERROR(placeholder_builder.Finalize(&placeholder_def));
      Status s;
      placeholder_node = g->AddNode(placeholder_def, &s);
      TF_RETURN_IF_ERROR(s);
      placeholders[placeholder_index] = placeholder_node;
    } else {
      placeholder_node = iter->second;
    }
    g->AddEdge(placeholder_node, 0, dst, dst_input);

    // Replace `e->dst()` because its input node changed: the NodeDef's textual
    // input must name the placeholder too. ReplaceNode (tf2xla_util) returns
    // the replacement node; the old node is presumably removed (see the
    // comment below). TODO confirm against ReplaceNode.
    NodeDef new_def = dst->def();
    *new_def.mutable_input(dst_input) = placeholder_node->name();
    TF_ASSIGN_OR_RETURN(Node * dst_replace_node, ReplaceNode(g, dst, new_def));

    // Other edge in `edges` might have `e->dst()` as src or dst
    // node. Before removing `e->dst()`, replace those edges with
    // corresponding edges for `dst_replace_node`.
    for (int j = i + 1, end = edges.size(); j < end; j++) {
      if (edges[j].dst_node_id == edges[i].dst_node_id) {
        edges[j].dst_node_id = dst_replace_node->id();
      }
    }
  }
  return Status::OK();
}
196 
// Step 1 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
// `PostprocessEdgesBetweenOutsideCompilations` for details.
//
// Undoes the placeholder insertion of
// PreprocessDataEdgesBetweenOutsideCompilations. For every "Placeholder" node
// carrying kOutsideCompilationOriginalNodeAttrName:
//  - each control out-edge is moved onto the original node;
//  - each data consumer is rebuilt (via ReplaceNode) to read
//    "<original>:<src_output>" directly;
//  - the placeholder is then removed.
// Returns an Internal error if the recorded original node no longer exists.
// NOTE(review): `outside_compilation_attr_name` is not used by this function.
Status PostprocessDataEdgesBetweenOutsideCompilations(
    Graph* g, const string& outside_compilation_attr_name) {
  // Gather the oc -> oc placeholder nodes, identified by their type and the
  // original-node attribute.
  std::vector<Node*> placeholder_nodes;
  for (Node* n : g->nodes()) {
    if (n->type_string() == "Placeholder" &&
        HasNodeAttr(n->def(), kOutsideCompilationOriginalNodeAttrName)) {
      placeholder_nodes.push_back(n);
    }
  }

  // Remove the placeholder nodes, and reconnect original edge.
  auto node_name_index = g->BuildNodeNameIndex();
  for (auto n : placeholder_nodes) {
    string node_name;
    int node_src_output;
    TF_RETURN_IF_ERROR(GetNodeAttr(
        n->attrs(), kOutsideCompilationOriginalNodeAttrName, &node_name));
    TF_RETURN_IF_ERROR(GetNodeAttr(
        n->attrs(), kOutsideCompilationSrcOutputAttrName, &node_src_output));
    auto iter = node_name_index.find(node_name);
    if (iter == node_name_index.end()) {
      // NOTE(review): the message says "oc -> host" although these are
      // oc -> oc placeholders.
      return errors::Internal(
          "Cannot find original node for oc -> host placeholder node ",
          node_name);
    }

    // Change all usage node to use the original node instead.
    // Out-edges are snapshotted first because the loops below modify them.
    Node* original_node = iter->second;
    std::vector<const Edge*> control_edges;
    std::vector<OutEdgeInfo> data_edges;
    for (auto e : n->out_edges()) {
      if (e->IsControlEdge()) {
        control_edges.push_back(e);
      } else {
        data_edges.push_back({e->dst(), e->src_output(), e->dst_input()});
      }
    }
    for (const Edge* e : control_edges) {
      g->AddControlEdge(original_node, e->dst());
      g->RemoveEdge(e);
    }
    for (int i = 0, end = data_edges.size(); i < end; i++) {
      // Rebuild the consumer so its NodeDef input names the original output.
      Node* dst = data_edges[i].dst;
      NodeDef new_def = dst->def();
      int dst_input = data_edges[i].dst_input;
      *new_def.mutable_input(dst_input) =
          absl::StrCat(original_node->name(), ":", node_src_output);
      TF_ASSIGN_OR_RETURN(Node * replace_node, ReplaceNode(g, dst, new_def));

      // The replacement's input `dst_input` is expected to still come from
      // the placeholder; swap that edge for one from the original node.
      const Edge* edge_to_replace = nullptr;
      TF_RETURN_IF_ERROR(replace_node->input_edge(dst_input, &edge_to_replace));
      g->RemoveEdge(edge_to_replace);
      g->AddEdge(original_node, node_src_output, replace_node, dst_input);

      // Other edges might have `dst` as dst node. Update those edges with
      // `replace_node`.
      for (int j = i + 1, end = data_edges.size(); j < end; j++) {
        if (data_edges[j].dst == dst) {
          data_edges[j].dst = replace_node;
        }
      }

      // Other placeholder node might have `dst` as original node. Update
      // `node_name_index` with `replace_node`.
      node_name_index[replace_node->name()] = replace_node;
    }

    // Remove placeholder node.
    g->RemoveNode(n);
  }
  return Status::OK();
}
272 
273 // Step 2 for `PostprocessEdgesBetweenOutsideCompilations`. See comments of
274 // `PostprocessEdgesBetweenOutsideCompilations` for details.
PostprocessControlEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)275 Status PostprocessControlEdgesBetweenOutsideCompilations(
276     Graph* g, const string& outside_compilation_attr_name) {
277   auto node_name_index = g->BuildNodeNameIndex();
278 
279   // Reconnect outside compilation to outside compilation control edge.
280   for (Node* n : g->nodes()) {
281     std::vector<string> control_deps;
282     Status s =
283         GetNodeAttr(n->attrs(), kXlaControlDependenciesWithinXlaClusterAttrName,
284                     &control_deps);
285     if (!s.ok()) {
286       if (s.code() != error::NOT_FOUND) {
287         return s;
288       } else {
289         continue;
290       }
291     } else {
292       n->ClearAttr(kXlaControlDependenciesWithinXlaClusterAttrName);
293       for (const string& control_input : control_deps) {
294         auto iter = node_name_index.find(control_input);
295         if (iter == node_name_index.end()) {
296           return errors::Internal("Cannot find original node for ",
297                                   control_input);
298         }
299         g->AddControlEdge(iter->second, n);
300       }
301     }
302   }
303   return Status::OK();
304 }
305 }  // namespace
306 
// Node attribute names. They are used by the functions above before this
// point, so they must also be declared in encapsulate_util.h (included at
// the top of this file).

// Set by PerformStaticShapeInferenceBeforeEncapsulation: the list of inferred
// output shapes of a node.
const char kXlaInferredShapesAttrName[] = "_xla_inferred_shapes";

// Set to true on an outside compilation node that has a control edge to a
// node without the outside compilation attribute
// (PreprocessControlEdgesBetweenOutsideCompilations).
const char kXlaConnectedToXlaComputationAttrName[] =
    "_xla_connected_to_xla_computation";
// Set to true on an outside compilation node that has a control edge from a
// node without the outside compilation attribute.
const char kXlaConnectedFromXlaComputationAttrName[] =
    "_xla_connected_from_xla_computation";
// On placeholders created for oc -> oc data edges: name of the original src
// node.
const char kOutsideCompilationOriginalNodeAttrName[] =
    "_xla_oc_to_oc_node_name";
// On the same placeholders: output index of the original src node.
const char kOutsideCompilationSrcOutputAttrName[] = "_xla_oc_to_oc_src_output";
// List of src node names recorded on the dst of a removed control edge between
// different outside compilation clusters. It is turned back into control edges
// by PostprocessControlEdgesBetweenOutsideCompilations.
const char kXlaControlDependenciesWithinXlaClusterAttrName[] =
    "_xla_control_dependencies_within_xla_cluster";
// The following names are not referenced in this file.
const char kXlaIsLiftedArgAttrName[] = "_xla_is_lifted_arg";
const char kXlaLiftedArgOutsideCompilationAttrName[] = "_xla_lifted_arg_oc";
const char kXlaOutsideCompilationInputsAttrName[] = "_xla_oc_inputs";
const char kXlaIsPlaceholderForArg[] = "_xla_is_placeholder_for_arg";
322 
PerformStaticShapeInferenceBeforeEncapsulation(Graph * g)323 Status PerformStaticShapeInferenceBeforeEncapsulation(Graph* g) {
324   // Perform shape inference.
325   std::map<int, InferredShape> arg_shapes;
326   GraphShapeInfo shape_info;
327   TF_RETURN_IF_ERROR(
328       InferShapes(g, arg_shapes, /*fnlib_def=*/nullptr, &shape_info));
329 
330   // Add attribute for output shapes.
331   auto node_name_index = g->BuildNodeNameIndex();
332   for (auto iter : shape_info) {
333     std::vector<PartialTensorShape> output_shapes;
334     std::transform(iter.second.begin(), iter.second.end(),
335                    std::back_inserter(output_shapes),
336                    [](const InferredShape& inferred_shape) {
337                      return inferred_shape.shape;
338                    });
339     Node* n = node_name_index[iter.first];
340     n->AddAttr(kXlaInferredShapesAttrName, output_shapes);
341   }
342 
343   return Status::OK();
344 }
345 
346 StatusOr<std::unique_ptr<absl::flat_hash_map<string, std::vector<string>>>>
OutsideCompilationClusterDependencies(const Graph * g,const string & outside_compilation_attr_name)347 OutsideCompilationClusterDependencies(
348     const Graph* g, const string& outside_compilation_attr_name) {
349   auto cluster_deps = absl::make_unique<
350       absl::flat_hash_map<string, absl::flat_hash_set<string>>>();
351 
352   for (const Edge* e : g->edges()) {
353     auto src_outside_compilation =
354         GetStringAttr(*e->src(), outside_compilation_attr_name);
355     auto dst_outside_compilation =
356         GetStringAttr(*e->dst(), outside_compilation_attr_name);
357 
358     if (src_outside_compilation && dst_outside_compilation &&
359         *src_outside_compilation != *dst_outside_compilation) {
360       auto dst_deps_it = cluster_deps->find(*dst_outside_compilation);
361       if (dst_deps_it == cluster_deps->end()) {
362         cluster_deps->insert(std::make_pair(
363             *dst_outside_compilation,
364             absl::flat_hash_set<string>({*src_outside_compilation})));
365       } else {
366         dst_deps_it->second.insert(*src_outside_compilation);
367       }
368     }
369   }
370 
371   auto cluster_deps_ordered =
372       absl::make_unique<absl::flat_hash_map<string, std::vector<string>>>();
373 
374   for (auto it = cluster_deps->begin(); it != cluster_deps->end(); it++) {
375     std::vector<string> ordered_deps(it->second.begin(), it->second.end());
376     std::sort(ordered_deps.begin(), ordered_deps.end());
377     cluster_deps_ordered->insert(std::make_pair(it->first, ordered_deps));
378   }
379 
380   return std::move(cluster_deps_ordered);
381 }
382 
PreprocessEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)383 Status PreprocessEdgesBetweenOutsideCompilations(
384     Graph* g, const string& outside_compilation_attr_name) {
385   // Remove edges from source node to outside compilation nodes, and edges
386   // from outside compilation nodes to sink node.
387   std::vector<const Edge*> edges_to_remove;
388   for (const Edge* e : g->source_node()->out_edges()) {
389     if (HasNodeAttr(e->dst()->def(), outside_compilation_attr_name)) {
390       edges_to_remove.push_back(e);
391     }
392   }
393   for (const Edge* e : g->sink_node()->in_edges()) {
394     if (HasNodeAttr(e->src()->def(), outside_compilation_attr_name)) {
395       edges_to_remove.push_back(e);
396     }
397   }
398   for (auto e : edges_to_remove) {
399     g->RemoveEdge(e);
400   }
401 
402   TF_RETURN_IF_ERROR(PreprocessControlEdgesBetweenOutsideCompilations(
403       g, outside_compilation_attr_name));
404   TF_RETURN_IF_ERROR(PreprocessDataEdgesBetweenOutsideCompilations(
405       g, outside_compilation_attr_name));
406   return Status::OK();
407 }
408 
PostprocessEdgesBetweenOutsideCompilations(Graph * g,const string & outside_compilation_attr_name)409 Status PostprocessEdgesBetweenOutsideCompilations(
410     Graph* g, const string& outside_compilation_attr_name) {
411   TF_RETURN_IF_ERROR(PostprocessDataEdgesBetweenOutsideCompilations(
412       g, outside_compilation_attr_name));
413   TF_RETURN_IF_ERROR(PostprocessControlEdgesBetweenOutsideCompilations(
414       g, outside_compilation_attr_name));
415   return Status::OK();
416 }
417 
418 }  // namespace tensorflow
419