/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/compiler/tf2xla/tf2xla_util.h"

#include <functional>
#include <queue>
#include <random>
#include <set>
#include <unordered_map>

#include "absl/strings/str_cat.h"
#include "tensorflow/compiler/tf2xla/sharding_util.h"
#include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/graph_to_functiondef.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/graph/tensor_id.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {

namespace {

Status ValidateTensorId(const tf2xla::TensorId& id) {
  if (id.node_name().empty()) {
    return errors::InvalidArgument("TensorId node_name must be non-empty");
  }
  if (id.output_index() < 0) {
    return errors::InvalidArgument(
        "TensorId output_index must be non-negative");
  }
  return Status::OK();
}

Status CheckNameDuplicates(const string& kind, const string& name,
                           std::set<string>* names) {
  if (!name.empty()) {
    if (!names->insert(name).second) {
      return errors::InvalidArgument("duplicate ", kind, " name: ", name);
    }
  }
  return Status::OK();
}

Status CheckFeedFetchNameConflicts(const string& kind,
                                   const std::set<string>& names) {
  // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
  // since that will cause a collision in codegen symbols.
  for (const string& name : names) {
    const string name_data(name + "_data");
    if (names.find(name_data) != names.end()) {
      return errors::InvalidArgument("conflicting ", kind, " name: ", name,
                                     " and ", name_data);
    }
  }
  return Status::OK();
}

// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
// `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
Status CopyAssociatedFunctions(Graph* g,
                               const FunctionLibraryDefinition* lookup_fld,
                               FunctionLibraryDefinition* fld) {
  for (Node* n : g->op_nodes()) {
    for (const auto& associated_function :
         GetAssociatedFunctions(*n, lookup_fld)) {
      switch (associated_function.type()) {
        case AssociatedFunctionInfo::kFunctionCallNode: {
          const FunctionDef* fdef =
              lookup_fld->Find(associated_function.func_name());
          if (!fdef) {
            return errors::Internal(
                "Cannot find function ", associated_function.func_name(),
                " for function call node ", n->DebugString());
          }
          TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
          break;
        }
        case AssociatedFunctionInfo::kSymbolicGradient:
        case AssociatedFunctionInfo::kFunctionAttr:
          break;
      }
    }
  }
  return Status::OK();
}

// For graph `g`, replaces _Arg nodes whose "index" attribute is in
// `const_input_index_to_node` with Const nodes.
Status ReplaceArgUsageWithConstNode(
    Graph* g,
    const std::unordered_map<int, const Node*>& const_input_index_to_node) {
  // Collect all _Arg nodes.
  std::unordered_map<int, Node*> arg_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      arg_nodes[index] = n;
    }
  }

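  // For each const input, add a copy of the Const node to the graph and
  // redirect every usage of the corresponding _Arg node to it.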
  for (const auto& iter : const_input_index_to_node) {
    int arg_index = iter.first;
    NodeDef const_def = iter.second->def();
    const_def.set_name(g->NewName(const_def.name()));
    Status s;
    Node* const_node = g->AddNode(const_def, &s);
    TF_RETURN_IF_ERROR(s);

    Node* arg_node = arg_nodes[arg_index];

    // Collect all usages of the _Arg node.
    struct OutEdgeInfo {
      int dst_node_id, dst_input;
    };
    std::vector<OutEdgeInfo> usages;
    for (const Edge* e : arg_node->out_edges()) {
      if (e->IsControlEdge()) {
        continue;
      }
      usages.push_back({e->dst()->id(), e->dst_input()});
    }

    for (int i = 0, end = usages.size(); i < end; i++) {
      // Make a copy of `usage_node`, and change its input to const node.
      Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
      NodeDef replace_def = usage_node->def();
      *replace_def.mutable_input(usages[i].dst_input) = const_node->name();
      TF_ASSIGN_OR_RETURN(Node * replace_node,
                          ReplaceNode(g, usage_node, replace_def));
      const Edge* usage_edge;
      TF_RETURN_IF_ERROR(
          replace_node->input_edge(usages[i].dst_input, &usage_edge));
      g->RemoveEdge(usage_edge);
      g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);

      // Later entries in `usages` might have `usage_node` as dst node, but
      // `usage_node` is removed. Replace such entries with `replace_node`.
      for (int j = i + 1, end = usages.size(); j < end; j++) {
        if (usages[j].dst_node_id == usages[i].dst_node_id) {
          usages[j].dst_node_id = replace_node->id();
        }
      }
    }
  }
  return Status::OK();
}

// For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
// the function to replace _Arg nodes in `const_input_index_to_node` with Const
// inputs.
Status PropagateConstIntoFuncAttr(
    Node* n, const string& attr_name,
    const std::unordered_map<int, const Node*>& const_input_index_to_node,
    const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  // Instantiate the function.
  NameAttrList func_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
  const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
  if (!fdef) {
    return errors::Internal("Cannot find function ", func_attr.name(),
                            " for node ", n->name());
  }
  std::unique_ptr<FunctionBody> fbody;
  TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
      *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));

  // Rewrite _Arg usages with Const node.
  Graph* func_graph = fbody->graph;
  TF_RETURN_IF_ERROR(
      ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));

  // Save rewritten function.
  FunctionDef replace_fdef;
  string new_func_name =
      fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
  TF_RETURN_IF_ERROR(
      GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
  TF_RETURN_IF_ERROR(fld->AddFunctionDef(
      replace_fdef, lookup_fld->GetStackTraces(func_attr.name())));

  // Change the node to use rewritten function.
  func_attr.set_name(new_func_name);
  n->ClearAttr(attr_name);
  n->AddAttr(attr_name, func_attr);
  // Copy associated functions.
  TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));

  return Status::OK();
}

// For an "If" node in graph `g`, if it has Const node inputs, rewrite its
// then/else branch function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
                                const FunctionLibraryDefinition* lookup_fld,
                                FunctionLibraryDefinition* fld) {
  // Note that the first input of an If node is the predicate; the remaining
  // inputs are the function inputs.
  std::unordered_map<int, const Node*> const_input_index_to_node;
  for (int i = 1; i < if_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
    if (input_node->type_string() == "Const") {
      const_input_index_to_node[i - 1] = input_node;
    }
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite the "then_branch" and "else_branch" functions, replacing usages of
  // those _Arg nodes with the corresponding Const nodes.
  for (const auto& attr_name :
       std::vector<string>{"then_branch", "else_branch"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }

  return Status::OK();
}

// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
// cond/body function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
                                   const FunctionLibraryDefinition* lookup_fld,
                                   FunctionLibraryDefinition* fld) {
  // For a "While" node, we should only replace _Arg nodes which are loop
  // invariants. For such an _Arg node, the corresponding return value's input
  // comes directly from that arg.
  std::unordered_map<int, const Node*> const_input_index_to_node;
  NameAttrList body_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
  const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
  if (!body_func) {
    return errors::Internal("Cannot find body function ", body_attr.name(),
                            " for While node ", while_node->name());
  }
  for (int i = 0; i < while_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node));
    if (input_node->type_string() != "Const") {
      continue;
    }

    // Check if the i-th retval's input comes directly from the i-th arg.
    // For resource variable inputs of While nodes, the TF2XLA convention is to
    // place them at the end of all inputs (after all data inputs), and *not*
    // return them. So the number of While node inputs might be larger than the
    // number of its outputs.
    if (i >= body_func->signature().output_arg_size()) {
      continue;
    }
    const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
    auto output_arg_input = body_func->ret().find(output_arg.name());
    if (output_arg_input == body_func->ret().end()) {
      return errors::Internal("Cannot find input for output arg ",
                              output_arg.name(), " in function ",
                              body_attr.name());
    }
    const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
    if (output_arg_input->second != input_arg.name()) {
      continue;
    }

    const_input_index_to_node[i] = input_node;
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite the "cond" and "body" functions, replacing usages of those _Arg
  // nodes with the corresponding Const nodes.
  for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        while_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }
  return Status::OK();
}

}  // namespace

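// Attribute names used to tag nodes that belong to a TPU replication cluster
// and to an XLA outside-compilation cluster, respectively.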
const char kTpuReplicateAttrName[] = "_tpu_replicate";
const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";

Status ValidateConfig(const tf2xla::Config& config) {
  std::set<string> names;
  for (const tf2xla::Feed& feed : config.feed()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
    TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
  names.clear();
  for (const tf2xla::Fetch& fetch : config.fetch()) {
    TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
    TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
  }
  TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
  if (config.fetch().empty()) {
    return errors::InvalidArgument("fetches must be specified");
  }
  return Status::OK();
}

Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
  struct PlaceholderInfo {
    const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
    string placeholder_name;
    DataType data_type = DT_INVALID;
  };

  // Put each fed tensor into a map by name:port. A map is used for determinism
  // when creating placeholders (genrules want deterministic output).
  std::map<string, PlaceholderInfo> placeholder_info;
  for (int i = 0; i < config.feed_size(); ++i) {
    const tf2xla::Feed* feed = &config.feed(i);
    const string name_port = TensorIdToString(feed->id());
    PlaceholderInfo& info = placeholder_info[name_port];
    info.feed = feed;
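    // Derive the placeholder name from the fed tensor's output index and node
    // name so that each fed port gets a distinct placeholder.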
    info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
                                         "/", feed->id().node_name());
    (*feed_remapping)[name_port] = info.placeholder_name;
  }

  // Verify node exists and determine data type.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (int i = 0; i < graph_def->node_size(); ++i) {
    name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
  }
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    PlaceholderInfo& info = it->second;
    const tf2xla::TensorId& feed_id = info.feed->id();

    // Find the existing node and determine data type.
    auto node_it = name_to_node.find(feed_id.node_name());
    if (node_it == name_to_node.end()) {
      return errors::NotFound("Can't find feed node: ",
                              TensorIdToString(feed_id));
    }
    const NodeDef* existing = node_it->second;

    if (info.feed->type() != DT_INVALID) {
      info.data_type = info.feed->type();
    } else {
      // Build the node in order to infer its type.

      // Must first add default attrs as well, so do this in a copied GraphDef.
      GraphDef gd;
      *gd.mutable_versions() = graph_def->versions();
      *gd.add_node() = *existing;
      MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
      TF_RETURN_IF_ERROR(
          AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));

      // Now build the node from the copied node def.
      Graph g(op_registry);
      g.set_versions(graph_def->versions());
      Status status;
      Node* feed_node = g.AddNode(gd.node(0), &status);
      TF_RETURN_IF_ERROR(status);

      if (info.feed->id().output_index() < feed_node->num_outputs()) {
        info.data_type =
            BaseType(feed_node->output_type(info.feed->id().output_index()));
      } else {
        return errors::InvalidArgument(
            "Invalid output_index ", info.feed->id().output_index(),
            " for feed node ", info.feed->id().node_name());
      }
    }
  }

  // Create placeholders. Note that we could avoid creating a placeholder for
  // feeds which are already placeholders, but we omit that to avoid more cases
  // in this code.
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    const PlaceholderInfo& info = it->second;
    // TODO(shikharagarwal): Add original node information.
    NodeDef* d = graph_def->add_node();
    d->set_name(info.placeholder_name);
    d->set_op("Placeholder");
    auto& attr_map = *d->mutable_attr();
    attr_map["dtype"].set_type(info.data_type);
    *attr_map["shape"].mutable_shape() = info.feed->shape();
  }

  // Rewrite references to the fed tensors to refer to the placeholder.
  for (int i = 0; i < graph_def->node_size(); ++i) {
    NodeDef* node_def = graph_def->mutable_node(i);
    for (int j = 0; j < node_def->input_size(); ++j) {
      auto id = ParseTensorName(node_def->input(j));
      auto it = placeholder_info.find(id.ToString());
      if (it != placeholder_info.end()) {
        node_def->set_input(j, it->second.placeholder_name);
      }
    }
  }

  return Status::OK();
}

Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out) {
  *out = in;
  out->clear_node();

  // Tensors needed for feeding.
  std::set<std::pair<string, int>> feed_tensors;
  for (const tf2xla::Feed& feed : config.feed()) {
    feed_tensors.insert(
        std::make_pair(feed.id().node_name(), feed.id().output_index()));
  }

  // Maps node name to reachability.
  std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
  for (const NodeDef& node : in.node()) {
    node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
  }

  // Traverse.
  std::queue<string> name_queue;
  for (int i = 0; i < config.fetch_size(); ++i) {
    name_queue.push(config.fetch(i).id().node_name());
  }
  while (!name_queue.empty()) {
    const string name = name_queue.front();
    name_queue.pop();

    auto find_it = node_by_name.find(name);
    if (find_it == node_by_name.end()) {
      return errors::InvalidArgument("While pruning graph, node ", name,
                                     " needed but not found in the graph.");
    }
    auto& map_entry = find_it->second;
    if (map_entry.first) {
      continue;
    }
    map_entry.first = true;

    // Push input nodes of the currently visited node to name_queue.
    for (const string& in_edge : map_entry.second->input()) {
      auto id = ParseTensorName(in_edge);
      const string node_name = string(id.first);
      if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
          feed_tensors.end()) {
        name_queue.push(node_name);
      } else {
        // The input tensor is from an edge that is being fed. Therefore,
        // we skip recursing down that edge, to avoid requiring nodes that
        // may not be needed (note that the input node may still be added
        // to name_queue later if one of its output edges is not being fed).
      }
    }
  }

  // Copy over the nodes, in the original order, keeping only those reachable
  // from the fetches.
  out->mutable_node()->Reserve(in.node_size());
  for (const NodeDef& node : in.node()) {
    if (node_by_name[node.name()].first) {
      *out->add_node() = node;
    }
  }
  return Status::OK();
}

string TensorIdToString(const tf2xla::TensorId& id) {
  return absl::StrCat(id.node_name(), ":", id.output_index());
}

Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
  int core = -1;
  const Node* matching_node = nullptr;
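  // Scan the neighboring nodes (edge destinations if `out_edges` is true,
  // otherwise edge sources) and remember the neighbor assigned, via maximal
  // sharding, to the lowest-numbered core.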
  for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
    if (edge->IsControlEdge()) continue;
    const Node* possible_match = out_edges ? edge->dst() : edge->src();
    TF_ASSIGN_OR_RETURN(
        absl::optional<xla::OpSharding> sharding,
        ParseShardingFromDevice(
            *possible_match,
            /*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
            /*add_metadata=*/false));
    if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
      const int core_annotation = sharding.value().tile_assignment_devices(0);
      if (core == -1 || core > core_annotation) {
        core = core_annotation;
        matching_node = possible_match;
      }
    }
  }
  if (matching_node != nullptr) {
    n->set_assigned_device_name(matching_node->assigned_device_name());
    n->set_requested_device(matching_node->requested_device());
  }
  return Status::OK();
}

void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
                                   KernelDef* kdef) {
  for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
    if (constraint.name() == name) {
      constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
    }
  }
}

namespace {
uint32 InitialRandomSeed() {
  // Support for plumbing the TF seed through to XLA is still being worked on.
  // If a user wants deterministic behavior, their best option
  // is to start with a known checkpoint. This also handles issues when
  // multiple random calls can be invoked in any order by the TF executor.
  // Another option is to use stateless random ops, which have much cleaner
  // semantics.
  // If a user really wants to set a deterministic seed for XLA-based
  // devices, this is the place to do it.
  std::random_device rd;
  // Make the starting value odd.
  return rd() | 1;
}
}  // namespace

uint32 GetXLARandomSeed() {
  // We initialize the counter with an odd number and increment it by two
  // every time. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
  static std::atomic<uint32> counter(InitialRandomSeed());
  uint32 seed = counter.fetch_add(2);
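  // Mix the counter value through the C library RNG, forcing the result to be
  // odd so the returned seed is never zero.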
  std::srand(seed);
  return std::rand() | 1;
}

// TODO(b/77601805): add tests for associated function related stuff.
bool HasAssociatedFunction(const NodeDef& node_def,
                           const FunctionLibraryDefinition* fld) {
  if (fld->Contains(node_def.op())) {
    return true;
  }

  if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
    // Gradient op has "f" attr, which is set to the function we are getting
    // gradient for. We need to functionalize the gradient function.
    return true;
  }

  if (node_def.op() == "XlaHostCompute") {
    // XlaHostCompute has "shape_inference_graph" func attr, but that's not
    // related to graph execution.
    return false;
  }

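  // Any other node that carries a function-valued attr also has an associated
  // function.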
  for (const auto& iter : node_def.attr()) {
    if (iter.second.has_func()) {
      return true;
    }
  }

  return false;
}

std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
    const Node& node, const FunctionLibraryDefinition* fld) {
  std::vector<AssociatedFunctionInfo> results;
  const string& op = node.type_string();
  if (fld->Contains(op)) {
    // This is a function call node.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
  } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
    // This is a SymbolicGradient op.
    AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
    results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
  } else if (node.type_string() == "XlaHostCompute") {
    // XlaHostCompute has "shape_inference_graph" func attr, but that's not
    // related to graph execution.
  } else {
    // Collect all function attrs for the node.
    for (auto& iter : node.attrs()) {
      if (iter.second.has_func()) {
        VLOG(2) << "Found function attr for node " << node.name() << ": "
                << iter.first << " = " << iter.second.func().name();
        results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
            iter.second.func().name(), iter.second.func().attr(), iter.first));
      }
    }
  }
  return results;
}

Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name) {
  switch (associated_function.type()) {
    case AssociatedFunctionInfo::kFunctionCallNode: {
      // Change this node to call the new function.
      NodeDebugInfo debug_info(*node);
      NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
                             &debug_info);
      for (const auto& attr : node->attrs()) {
        builder.Attr(attr.first, attr.second);
      }
      for (int i = 0; i < node->num_inputs(); i++) {
        Node* input_node;
        TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
        builder.Input(input_node->name(), i, node->input_type(i));
      }
      builder.Device(node->assigned_device_name().empty()
                         ? node->requested_device()
                         : node->assigned_device_name());
      NodeDef node_def;
      TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
      Status s;
      Node* new_node = graph->AddNode(node_def, &s);
      TF_RETURN_IF_ERROR(s);
      for (auto edge : node->in_edges()) {
        graph->AddEdge(edge->src(), edge->src_output(), new_node,
                       edge->dst_input());
      }
      for (auto edge : node->out_edges()) {
        graph->AddEdge(new_node, edge->src_output(), edge->dst(),
                       edge->dst_input());
      }
      graph->RemoveNode(node);
      break;
    }
    case AssociatedFunctionInfo::kSymbolicGradient: {
      NameAttrList func;
      TF_RETURN_IF_ERROR(GetNodeAttr(
          node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
      GradientDef gradient_def;
      gradient_def.set_function_name(func.name());
      gradient_def.set_gradient_func(rewritten_function_name);
      string original_grad_func = fld->FindGradient(func.name());
      if (original_grad_func.empty()) {
        TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
      } else if (original_grad_func != rewritten_function_name) {
        TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
      }
      break;
    }
    case AssociatedFunctionInfo::kFunctionAttr: {
      // Change the function attr to the rewritten function.
      NameAttrList func;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
      node->ClearAttr(associated_function.attr_name());
      func.set_name(rewritten_function_name);
      node->AddAttr(associated_function.attr_name(), func);
      break;
    }
  }

  return Status::OK();
}

Status CachedFunctionHandles::GetOrInstantiate(
    const string& func_name, AttrSlice attrs,
    FunctionLibraryRuntime::Handle* handle) {
  string canonicalized_name = Canonicalize(func_name, attrs);
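  // Reuse the cached handle if this function was already instantiated with the
  // same attributes.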
  auto iter = handles_.find(canonicalized_name);
  if (iter != handles_.end()) {
    *handle = iter->second;
    return Status::OK();
  }

  TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
  handles_[canonicalized_name] = *handle;
  return Status::OK();
}

Status CachedFunctionHandles::ReleaseAllHandles() {
  Status result;
  for (const auto& iter : handles_) {
    result.Update(flr_->ReleaseHandle(iter.second));
  }
  handles_.clear();
  return result;
}

xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
  // Create the replacement node.
  Status s;
  Node* new_node = g->AddNode(node_def, &s);
  if (!s.ok()) {
    return s;
  }

  // Record the original node's output edges and remove them first, so that the
  // destination nodes' inputs do not end up with multiple producers.
  std::vector<OutEdgeInfo> out_edge_info;
  std::vector<const Edge*> out_edges;
  for (const Edge* edge : n->out_edges()) {
    out_edges.push_back(edge);
    out_edge_info.push_back(
        {edge->dst(), edge->src_output(), edge->dst_input()});
  }
  for (const Edge* edge : out_edges) {
    g->RemoveEdge(edge);
  }

  // Add the original node's input and output edges to the replacement node.
  for (const Edge* in_edge : n->in_edges()) {
    g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
               in_edge->dst_input());
  }
  for (const OutEdgeInfo& out_edge : out_edge_info) {
    g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
  }

  // Remove the original node.
  g->RemoveNode(n);

  return new_node;
}

xla::StatusOr<Node*> BuildIdentityNode(
    Graph* graph, const string& node_name, DataType dtype, const Node* input,
    absl::optional<string> requested_device) {
  // Create identity node.
  NodeDef ndef;
  ndef.set_name(node_name);
  ndef.set_op("Identity");
  if (input) {
    ndef.add_input(input->name());
  }
  if (requested_device) {
    ndef.set_device(*requested_device);
  }
  AddNodeAttr("T", dtype, &ndef);
  Status s;
  Node* id_node = graph->AddNode(ndef, &s);
  TF_RETURN_IF_ERROR(s);
  return id_node;
}

Status PropagateConstIntoFunctionalNodes(
    Graph* g, const FunctionLibraryDefinition* lookup_fld,
    FunctionLibraryDefinition* fld) {
  for (Node* n : g->op_nodes()) {
    if (n->IsIfNode()) {
      TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
    } else if (n->IsWhileNode()) {
      TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
    }
  }
  return Status::OK();
}

Status PruneUnreachableFunctionsFromGraph(const Graph& g,
                                          FunctionLibraryDefinition* fld) {
  GraphDef graph_def;
  g.ToGraphDef(&graph_def);
  FunctionLibraryDefinition reachable_functions =
      fld->ReachableDefinitions(graph_def);
  for (const string& func_name : fld->ListFunctionNames()) {
    if (!reachable_functions.Find(func_name)) {
      TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
    }
  }
  return Status::OK();
}

Status RewriteTensorListWithConstElement(Graph* g,
                                         FunctionLibraryDefinition* fld) {
  for (Node* n : g->nodes()) {
    if (n->type_string() != "EmptyTensorList") {
      continue;
    }

    // Find the forward While op.
    std::vector<const Edge*> fwd_while_edges;
    for (const Edge* e : n->out_edges()) {
      if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
        fwd_while_edges.push_back(e);
      }
    }
    if (fwd_while_edges.size() != 1) {
      // No forward While op found, or multiple forward While ops.
      continue;
    }

    // Find the backward While op.
    Node* fwd_while = fwd_while_edges[0]->dst();
    int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
    std::vector<const Edge*> bwd_while_edges;
    for (const Edge* e : fwd_while->out_edges()) {
      if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
        bwd_while_edges.push_back(e);
      }
    }
    if (bwd_while_edges.size() != 1) {
      // No backward While op found, or multiple backward While ops.
      continue;
    }

    Node* bwd_while = bwd_while_edges[0]->dst();
    int bwd_while_dst_input = bwd_while_edges[0]->dst_input();

    // Look into forward While body function and check if TensorListPushBack op
    // has a Const input.
    NameAttrList fwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
    const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
    if (!fwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     fwd_body_attr.name(), " for While node ",
                                     fwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> fwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));

    // Find the TensorListPushBack node; it's one of fwd_arg's successors.
    Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
    std::vector<Node*> tl_push_nodes;
    for (const Edge* out_edge : fwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPushBack") {
        tl_push_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_push_nodes.size() != 1) {
      // No TensorListPushBack found, or multiple TensorListPushBack.
      continue;
    }

    // Get input for the TensorListPushBack node.
    Node* input_node;
    TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
    if (input_node->type_string() != "Const") {
      // Input for the TensorList is not Const node.
      continue;
    }

    NodeDef const_input_nodedef = input_node->def();

    // Rewrite backward While body function, replace usages of
    // TensorListPopBack with a Const node.
    NameAttrList bwd_body_attr;
    TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
    const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
    if (!bwd_body) {
      return errors::InvalidArgument("Cannot find function ",
                                     bwd_body_attr.name(), " for While node ",
                                     bwd_while->DebugString());
    }
    std::unique_ptr<FunctionBody> bwd_fbody;
    TF_CHECK_OK(FunctionDefToBodyHelper(
        *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));

    // Find the TensorListPopBack node; it's one of bwd_arg's successors.
    Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
    std::vector<Node*> tl_pop_nodes;
    for (const Edge* out_edge : bwd_arg->out_edges()) {
      if (out_edge->dst()->type_string() == "TensorListPopBack") {
        tl_pop_nodes.push_back(out_edge->dst());
      }
    }
    if (tl_pop_nodes.size() != 1) {
      // No TensorListPopBack found, or multiple TensorListPopBack.
      continue;
    }

    // Replace TensorListPopBack usages with Const node.
    std::vector<const Edge*> edges_to_replace;
    for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
      if (e->src_output() == 1) {
        edges_to_replace.push_back(e);
      }
    }
    if (edges_to_replace.empty()) {
      continue;
    }
    Status s;
    const_input_nodedef.set_name(
        bwd_fbody->graph->NewName(const_input_nodedef.name()));
    Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
    TF_RETURN_IF_ERROR(s);
    for (const Edge* e : edges_to_replace) {
      Node* dst = e->dst();
      int dst_input = e->dst_input();
      bwd_fbody->graph->RemoveEdge(e);
      bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
    }

    // Add rewritten backward While body function.
    FunctionDef new_fdef;
    string new_name = fld->UniqueFunctionName(
        absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
    TF_RETURN_IF_ERROR(
        GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
    TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));

    // Change backward While op to use the new body function.
    bwd_body_attr.set_name(new_name);
    bwd_while->ClearAttr("body");
    bwd_while->AddAttr("body", bwd_body_attr);
  }
  return Status::OK();
}

}  // namespace tensorflow