1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 
16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
17 
18 #include <functional>
19 #include <queue>
20 #include <random>
21 #include <set>
22 #include <unordered_map>
23 
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/compiler/tf2xla/sharding_util.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/common_runtime/function.h"
29 #include "tensorflow/core/framework/graph.pb.h"
30 #include "tensorflow/core/framework/graph_def_util.h"
31 #include "tensorflow/core/framework/graph_to_functiondef.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_builder.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_shape.pb.h"
37 #include "tensorflow/core/framework/versions.pb.h"
38 #include "tensorflow/core/graph/tensor_id.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41 
42 namespace tensorflow {
43 
44 namespace {
45 
ValidateTensorId(const tf2xla::TensorId & id)46 Status ValidateTensorId(const tf2xla::TensorId& id) {
47   if (id.node_name().empty()) {
48     return errors::InvalidArgument("TensorId node_name must be non-empty");
49   }
50   if (id.output_index() < 0) {
51     return errors::InvalidArgument("TensorId output_index must be positive");
52   }
53   return Status::OK();
54 }
55 
CheckNameDuplicates(const string & kind,const string & name,std::set<string> * names)56 Status CheckNameDuplicates(const string& kind, const string& name,
57                            std::set<string>* names) {
58   if (!name.empty()) {
59     if (!names->insert(name).second) {
60       return errors::InvalidArgument("duplicate ", kind, " name: ", name);
61     }
62   }
63   return Status::OK();
64 }
65 
CheckFeedFetchNameConflicts(const string & kind,const std::set<string> & names)66 Status CheckFeedFetchNameConflicts(const string& kind,
67                                    const std::set<string>& names) {
68   // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
69   // since that will cause a collision in codegen symbols.
70   for (const string& name : names) {
71     const string name_data(name + "_data");
72     if (names.find(name_data) != names.end()) {
73       return errors::InvalidArgument("conflicting ", kind, " name: ", name,
74                                      " and ", name_data);
75     }
76   }
77   return Status::OK();
78 }
79 
80 // For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
81 // `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
CopyAssociatedFunctions(Graph * g,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)82 Status CopyAssociatedFunctions(Graph* g,
83                                const FunctionLibraryDefinition* lookup_fld,
84                                FunctionLibraryDefinition* fld) {
85   for (Node* n : g->op_nodes()) {
86     for (const auto& associated_function :
87          GetAssociatedFunctions(*n, lookup_fld)) {
88       switch (associated_function.type()) {
89         case AssociatedFunctionInfo::kFunctionCallNode: {
90           const FunctionDef* fdef =
91               lookup_fld->Find(associated_function.func_name());
92           if (!fdef) {
93             return errors::Internal(
94                 "Cannot find function ", associated_function.func_name(),
95                 " for function call node ", n->DebugString());
96           }
97           TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
98           break;
99         }
100         case AssociatedFunctionInfo::kSymbolicGradient:
101         case AssociatedFunctionInfo::kFunctionAttr:
102           break;
103       }
104     }
105   }
106   return Status::OK();
107 }
108 
109 // For graph `g`, replaces _Arg nodes whose "index" attribute is in
110 // `const_input_index_to_node` with Const nodes.
ReplaceArgUsageWithConstNode(Graph * g,const std::unordered_map<int,const Node * > & const_input_index_to_node)111 Status ReplaceArgUsageWithConstNode(
112     Graph* g,
113     const std::unordered_map<int, const Node*>& const_input_index_to_node) {
114   // Collect all _Arg nodes.
115   std::unordered_map<int, Node*> arg_nodes;
116   for (Node* n : g->op_nodes()) {
117     if (n->IsArg()) {
118       int index;
119       TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
120       arg_nodes[index] = n;
121     }
122   }
123 
124   for (const auto& iter : const_input_index_to_node) {
125     int arg_index = iter.first;
126     NodeDef const_def = iter.second->def();
127     const_def.set_name(g->NewName(const_def.name()));
128     Status s;
129     Node* const_node = g->AddNode(const_def, &s);
130     TF_RETURN_IF_ERROR(s);
131 
132     Node* arg_node = arg_nodes[arg_index];
133 
134     // Collect all usages of the _Arg node.
135     struct OutEdgeInfo {
136       int dst_node_id, dst_input;
137     };
138     std::vector<OutEdgeInfo> usages;
139     for (const Edge* e : arg_node->out_edges()) {
140       if (e->IsControlEdge()) {
141         continue;
142       }
143       usages.push_back({e->dst()->id(), e->dst_input()});
144     }
145 
146     for (int i = 0; i < usages.size(); i++) {
147       // Make a copy of `usage_node`, and change its input to const node.
148       Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
149       NodeDef replace_def = usage_node->def();
150       *replace_def.mutable_input(usages[i].dst_input) = const_node->name();
151       TF_ASSIGN_OR_RETURN(Node * replace_node,
152                           ReplaceNode(g, usage_node, replace_def));
153       const Edge* usage_edge;
154       TF_RETURN_IF_ERROR(
155           replace_node->input_edge(usages[i].dst_input, &usage_edge));
156       g->RemoveEdge(usage_edge);
157       g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);
158 
159       // Later entries in `usages` might have `usage_node` as dst node, but
160       // `usage_node` is removed. Replace such entries with `replace_node`.
161       for (int j = i + 1; j < usages.size(); j++) {
162         if (usages[j].dst_node_id == usages[i].dst_node_id) {
163           usages[j].dst_node_id = replace_node->id();
164         }
165       }
166     }
167   }
168   return Status::OK();
169 }
170 
171 // For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
172 // the function to replace _Arg nodes in `const_input_index_to_node` with Const
173 // inputs.
PropagateConstIntoFuncAttr(Node * n,const string & attr_name,const std::unordered_map<int,const Node * > & const_input_index_to_node,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)174 Status PropagateConstIntoFuncAttr(
175     Node* n, const string& attr_name,
176     const std::unordered_map<int, const Node*>& const_input_index_to_node,
177     const FunctionLibraryDefinition* lookup_fld,
178     FunctionLibraryDefinition* fld) {
179   // Instantiate the function.
180   NameAttrList func_attr;
181   TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
182   const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
183   if (!fdef) {
184     return errors::Internal("Cannot find function ", func_attr.name(),
185                             " for node ", n->name());
186   }
187   FunctionBody* fbody;
188   TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
189       *fdef, AttrSlice(&func_attr.attr()), lookup_fld,
190       [lookup_fld](const string& op, const OpDef** sig) {
191         return lookup_fld->LookUpOpDef(op, sig);
192       },
193       &fbody));
194   std::unique_ptr<FunctionBody> fbody_deleter(fbody);
195 
196   // Rewrite _Arg usages with Const node.
197   Graph* func_graph = fbody->graph;
198   TF_RETURN_IF_ERROR(
199       ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));
200 
201   // Save rewritten function.
202   FunctionDef replace_fdef;
203   string new_func_name =
204       fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
205   TF_RETURN_IF_ERROR(
206       GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
207   TF_RETURN_IF_ERROR(fld->AddFunctionDef(replace_fdef));
208 
209   // Change the node to use rewritten function.
210   func_attr.set_name(new_func_name);
211   n->ClearAttr(attr_name);
212   n->AddAttr(attr_name, func_attr);
213 
214   // Copy associated functions.
215   TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));
216 
217   return Status::OK();
218 }
219 
220 // For an "If" node in graph `g`, if it has Const node inputs, rewrite its
221 // then/else branch function to replace _Arg nodes with those Const inputs.
PropagateConstIntoIfNode(Graph * g,Node * if_node,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)222 Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
223                                 const FunctionLibraryDefinition* lookup_fld,
224                                 FunctionLibraryDefinition* fld) {
225   // Notice that first input for If node is predicate; other inputs are function
226   // inputs.
227   std::unordered_map<int, const Node*> const_input_index_to_node;
228   for (int i = 1; i < if_node->num_inputs(); i++) {
229     const Node* input_node;
230     TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
231     if (input_node->type_string() == "Const") {
232       const_input_index_to_node[i - 1] = input_node;
233     }
234   }
235   if (const_input_index_to_node.empty()) {
236     return Status::OK();
237   }
238 
239   // Rewrite "then_branch" and "else_branch" function, replace usage of those
240   // _Arg nodes with corresponding const node.
241   for (const auto& attr_name :
242        std::vector<string>{"then_branch", "else_branch"}) {
243     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
244         if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
245   }
246 
247   return Status::OK();
248 }
249 
250 // For a "While" node in graph `g`, if it has Const node inputs, rewrite its
251 // cond/body function to replace _Arg nodes with those Const inputs.
PropagateConstIntoWhileNode(Graph * g,Node * while_node,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)252 Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
253                                    const FunctionLibraryDefinition* lookup_fld,
254                                    FunctionLibraryDefinition* fld) {
255   // For "While" node, we should only replace _Arg nodes which are loop
256   // invariants. For such _Arg nodes, the return value's input will come
257   // directly from the corresponding arg.
258   std::unordered_map<int, const Node*> const_input_index_to_node;
259   NameAttrList body_attr;
260   TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
261   const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
262   if (!body_func) {
263     return errors::Internal("Cannot find body function ", body_attr.name(),
264                             " for While node ", while_node->name());
265   }
266   for (int i = 0; i < while_node->num_inputs(); i++) {
267     const Node* input_node;
268     TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node));
269     if (input_node->type_string() != "Const") {
270       continue;
271     }
272 
273     // Check if i-th retval's input comes from i-th arg directly.
274     // For resource variable input of While nodes, TF2XLA convention is to place
275     // them at the end of all inputs (after all data inputs), and *not* return
276     // them. So number of While node inputs might be larger than number of its
277     // outputs.
278     if (i >= body_func->signature().output_arg_size()) {
279       continue;
280     }
281     const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
282     auto output_arg_input = body_func->ret().find(output_arg.name());
283     if (output_arg_input == body_func->ret().end()) {
284       return errors::Internal("Cannot find input for output arg ",
285                               output_arg.name(), " in function ",
286                               body_attr.name());
287     }
288     const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
289     if (output_arg_input->second != input_arg.name()) {
290       continue;
291     }
292 
293     const_input_index_to_node[i] = input_node;
294   }
295   if (const_input_index_to_node.empty()) {
296     return Status::OK();
297   }
298 
299   // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with
300   // corresponding const node.
301   for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
302     TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
303         while_node, attr_name, const_input_index_to_node, lookup_fld, fld));
304   }
305   return Status::OK();
306 }
307 
308 }  // namespace
309 
// Node attribute name "_xla_outside_compilation".
// NOTE(review): judging by the name, this presumably marks nodes that run
// outside XLA compilation; confirm against the consumers of this constant.
const char kXlaOutsideCompilationAttrName[] =
    "_xla_outside_compilation";
311 
ValidateConfig(const tf2xla::Config & config)312 Status ValidateConfig(const tf2xla::Config& config) {
313   std::set<string> names;
314   for (const tf2xla::Feed& feed : config.feed()) {
315     TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
316     TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
317     TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
318   }
319   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
320   names.clear();
321   for (const tf2xla::Fetch& fetch : config.fetch()) {
322     TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
323     TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
324   }
325   TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
326   if (config.fetch().empty()) {
327     return errors::InvalidArgument("fetches must be specified");
328   }
329   return Status::OK();
330 }
331 
AddPlaceholdersForFeeds(const tf2xla::Config & config,const OpRegistryInterface * op_registry,std::unordered_map<string,string> * feed_remapping,GraphDef * graph_def)332 Status AddPlaceholdersForFeeds(
333     const tf2xla::Config& config, const OpRegistryInterface* op_registry,
334     std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
335   struct PlaceholderInfo {
336     const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
337     string placeholder_name;
338     DataType data_type = DT_INVALID;
339   };
340 
341   // Put each fed tensor into a map by name:port. A map is used for determinism
342   // when creating placeholders (genrules want deterministic output).
343   std::map<string, PlaceholderInfo> placeholder_info;
344   for (int i = 0; i < config.feed_size(); ++i) {
345     const tf2xla::Feed* feed = &config.feed(i);
346     const string name_port = TensorIdToString(feed->id());
347     PlaceholderInfo& info = placeholder_info[name_port];
348     info.feed = feed;
349     info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
350                                          "/", feed->id().node_name());
351     (*feed_remapping)[name_port] = info.placeholder_name;
352   }
353 
354   // Verify node exists and determine data type.
355   std::unordered_map<string, const NodeDef*> name_to_node;
356   for (int i = 0; i < graph_def->node_size(); ++i) {
357     name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
358   }
359   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
360     PlaceholderInfo& info = it->second;
361     const tf2xla::TensorId& feed_id = info.feed->id();
362 
363     // Find the existing node and determine data type.
364     auto node_it = name_to_node.find(feed_id.node_name());
365     if (node_it == name_to_node.end()) {
366       return errors::NotFound("Can't find feed node: ",
367                               TensorIdToString(feed_id));
368     }
369     const NodeDef* existing = node_it->second;
370 
371     if (info.feed->type() != DT_INVALID) {
372       info.data_type = info.feed->type();
373     } else {
374       // Build the node in order to infer its type.
375 
376       // Must first add default attrs as well, so do this in a copied GraphDef.
377       GraphDef gd;
378       *gd.mutable_versions() = graph_def->versions();
379       *gd.add_node() = *existing;
380       MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
381       TF_RETURN_IF_ERROR(
382           AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));
383 
384       // Now build the node from the copied node def.
385       Graph g(op_registry);
386       g.set_versions(graph_def->versions());
387       Status status;
388       Node* feed_node = g.AddNode(gd.node(0), &status);
389       TF_RETURN_IF_ERROR(status);
390 
391       if (info.feed->id().output_index() < feed_node->num_outputs()) {
392         info.data_type =
393             BaseType(feed_node->output_type(info.feed->id().output_index()));
394       } else {
395         return errors::InvalidArgument(
396             "Invalid output_index ", info.feed->id().output_index(),
397             " for feed node ", info.feed->id().node_name());
398       }
399     }
400   }
401 
402   // Create placeholders. Note that we could avoid creating a placeholder for
403   // feeds which are already placeholders, but we omit that to avoid more cases
404   // in this code.
405   for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
406     const PlaceholderInfo& info = it->second;
407     // TODO(shikharagarwal): Add original node information.
408     NodeDef* d = graph_def->add_node();
409     d->set_name(info.placeholder_name);
410     d->set_op("PlaceholderV2");
411     auto& attr_map = *d->mutable_attr();
412     attr_map["dtype"].set_type(info.data_type);
413     *attr_map["shape"].mutable_shape() = info.feed->shape();
414   }
415 
416   // Rewrite references to the fed tensors to refer to the placeholder.
417   for (int i = 0; i < graph_def->node_size(); ++i) {
418     NodeDef* node_def = graph_def->mutable_node(i);
419     for (int j = 0; j < node_def->input_size(); ++j) {
420       auto id = ParseTensorName(node_def->input(j));
421       auto it = placeholder_info.find(id.ToString());
422       if (it != placeholder_info.end()) {
423         node_def->set_input(j, it->second.placeholder_name);
424       }
425     }
426   }
427 
428   return Status::OK();
429 }
430 
PruneGraphDefInto(const tf2xla::Config & config,const GraphDef & in,GraphDef * out)431 Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
432                          GraphDef* out) {
433   *out = in;
434   out->clear_node();
435 
436   // Tensors needed for feeding.
437   std::set<std::pair<string, int>> feed_tensors;
438   for (const tf2xla::Feed& feed : config.feed()) {
439     feed_tensors.insert(
440         std::make_pair(feed.id().node_name(), feed.id().output_index()));
441   }
442 
443   // Maps node name to reachability.
444   std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
445   for (const NodeDef& node : in.node()) {
446     node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
447   }
448 
449   // Traverse.
450   std::queue<string> name_queue;
451   for (int i = 0; i < config.fetch_size(); ++i) {
452     name_queue.push(config.fetch(i).id().node_name());
453   }
454   while (!name_queue.empty()) {
455     const string name = name_queue.front();
456     name_queue.pop();
457 
458     auto find_it = node_by_name.find(name);
459     if (find_it == node_by_name.end()) {
460       return errors::InvalidArgument("While pruning graph, node ", name,
461                                      " needed but not found in the graph.");
462     }
463     auto& map_entry = find_it->second;
464     if (map_entry.first) {
465       continue;
466     }
467     map_entry.first = true;
468 
469     // Push input nodes of the currently visited node to name_queue.
470     for (const string& in_edge : map_entry.second->input()) {
471       auto id = ParseTensorName(in_edge);
472       const string node_name = string(id.first);
473       if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
474           feed_tensors.end()) {
475         name_queue.push(node_name);
476       } else {
477         // The input tensor is from an edge that is being fed. Therefore,
478         // we skip recursing down that edge, to avoid requiring nodes that
479         // may not be needed (note that the input node may still be added
480         // to name_queue later if one of its output edges is not being fed).
481       }
482     }
483   }
484 
485   // Copy over, preserving order of original and only nodes that are reachable
486   // from the fetches.
487   out->mutable_node()->Reserve(in.node_size());
488   for (const NodeDef& node : in.node()) {
489     if (node_by_name[node.name()].first) {
490       *out->add_node() = node;
491     }
492   }
493   return Status::OK();
494 }
495 
TensorIdToString(const tf2xla::TensorId & id)496 string TensorIdToString(const tf2xla::TensorId& id) {
497   return absl::StrCat(id.node_name(), ":", id.output_index());
498 }
499 
SetNodeShardingFromNeighbors(Node * n,bool out_edges)500 Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
501   int core = -1;
502   const Node* matching_node = nullptr;
503   for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
504     if (edge->IsControlEdge()) continue;
505     const Node* possible_match = out_edges ? edge->dst() : edge->src();
506     TF_ASSIGN_OR_RETURN(
507         absl::optional<xla::OpSharding> sharding,
508         ParseShardingFromDevice(
509             *possible_match,
510             /*num_cores_per_replica=*/std::numeric_limits<int32>::max()));
511     if (sharding.has_value()) {
512       TF_RET_CHECK(sharding.value().type() ==
513                    xla::OpSharding::Type::OpSharding_Type_MAXIMAL);
514       const int core_annotation = sharding.value().tile_assignment_devices(0);
515       if (core == -1 || core > core_annotation) {
516         core = core_annotation;
517         matching_node = possible_match;
518       }
519     }
520   }
521   if (matching_node != nullptr) {
522     n->set_assigned_device_name(matching_node->assigned_device_name());
523     n->set_requested_device(matching_node->requested_device());
524   }
525   return Status::OK();
526 }
527 
AddDtypeToKernelDefConstraint(absl::string_view name,DataType dtype,KernelDef * kdef)528 void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
529                                    KernelDef* kdef) {
530   for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
531     if (constraint.name() == name) {
532       constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
533     }
534   }
535 }
536 
537 namespace {
InitialRandomSeed()538 uint32 InitialRandomSeed() {
539   // Support plumbing the TF seed through to XLA is being worked on.
540   // If a user wants deterministic behavior, their best option
541   // is to start with a known checkpoint. This also handles issues when
542   // multiple random calls can be invoked in any order by TF executor.
543   // Another option is to use stateless random ops. They have much cleaner
544   // semantics.
545   // If a user really wants to set a deterministic seed for XLA-based
546   // devices, this is the place to do it.
547   std::random_device rd;
548   // Make the starting value odd.
549   return rd() | 1;
550 }
551 }  // namespace
552 
GetXLARandomSeed()553 uint32 GetXLARandomSeed() {
554   // We initialize counter with an odd number and increment it by two
555   // everytime. This ensures that it will never be zero, even
556   // after an overflow. When seeded with zero, some XLA backends
557   // can return all zeros instead of random numbers.
558   static std::atomic<uint32> counter(InitialRandomSeed());
559   uint32 seed = counter.fetch_add(2);
560   std::srand(seed);
561   return std::rand() | 1;
562 }
563 
564 // TODO(b/77601805): add tests for associated function related stuff.
HasAssociatedFunction(const NodeDef & node_def,const FunctionLibraryDefinition * fld)565 bool HasAssociatedFunction(const NodeDef& node_def,
566                            const FunctionLibraryDefinition* fld) {
567   if (fld->Contains(node_def.op())) {
568     return true;
569   }
570 
571   if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
572     // Gradient op has "f" attr, which is set to the function we are getting
573     // gradient for. We need to functionalize the gradient function.
574     return true;
575   }
576 
577   if (node_def.op() == "XlaHostCompute") {
578     // XlaHostCompute has "shape_inference_graph" func attr, but that's not
579     // related to graph execution.
580     return false;
581   }
582 
583   for (const auto& iter : node_def.attr()) {
584     if (iter.second.has_func()) {
585       return true;
586     }
587   }
588 
589   return false;
590 }
591 
GetAssociatedFunctions(const Node & node,const FunctionLibraryDefinition * fld)592 std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
593     const Node& node, const FunctionLibraryDefinition* fld) {
594   std::vector<AssociatedFunctionInfo> results;
595   const string& op = node.type_string();
596   if (fld->Contains(op)) {
597     // This is a function call node.
598     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
599     results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
600   } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
601     // This is a SymbolicGradient op.
602     AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
603     results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
604   } else if (node.type_string() == "XlaHostCompute") {
605     // XlaHostCompute has "shape_inference_graph" func attr, but that's not
606     // related to graph execution.
607   } else {
608     // Collect all function attrs for the node.
609     for (auto& iter : node.attrs()) {
610       if (iter.second.has_func()) {
611         VLOG(2) << "Found function attr for node " << node.name() << ": "
612                 << iter.first << " = " << iter.second.func().name();
613         results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
614             iter.second.func().name(), iter.second.func().attr(), iter.first));
615       }
616     }
617   }
618   return results;
619 }
620 
RewriteAssociatedFunction(Graph * graph,Node * node,FunctionLibraryDefinition * fld,const AssociatedFunctionInfo & associated_function,const string & rewritten_function_name)621 Status RewriteAssociatedFunction(
622     Graph* graph, Node* node, FunctionLibraryDefinition* fld,
623     const AssociatedFunctionInfo& associated_function,
624     const string& rewritten_function_name) {
625   switch (associated_function.type()) {
626     case AssociatedFunctionInfo::kFunctionCallNode: {
627       // Change this node to call the new function.
628       NodeDebugInfo debug_info(*node);
629       NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
630                              &debug_info);
631       for (auto attr : node->attrs()) {
632         builder.Attr(attr.first, attr.second);
633       }
634       for (int i = 0; i < node->num_inputs(); i++) {
635         Node* input_node;
636         TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
637         builder.Input(input_node->name(), i, node->input_type(i));
638       }
639       builder.Device(node->assigned_device_name().empty()
640                          ? node->requested_device()
641                          : node->assigned_device_name());
642       NodeDef node_def;
643       TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
644       Status s;
645       Node* new_node = graph->AddNode(node_def, &s);
646       TF_RETURN_IF_ERROR(s);
647       for (auto edge : node->in_edges()) {
648         graph->AddEdge(edge->src(), edge->src_output(), new_node,
649                        edge->dst_input());
650       }
651       for (auto edge : node->out_edges()) {
652         graph->AddEdge(new_node, edge->src_output(), edge->dst(),
653                        edge->dst_input());
654       }
655       graph->RemoveNode(node);
656       break;
657     }
658     case AssociatedFunctionInfo::kSymbolicGradient: {
659       NameAttrList func;
660       TF_RETURN_IF_ERROR(GetNodeAttr(
661           node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
662       GradientDef gradient_def;
663       gradient_def.set_function_name(func.name());
664       gradient_def.set_gradient_func(rewritten_function_name);
665       string original_grad_func = fld->FindGradient(func.name());
666       if (original_grad_func.empty()) {
667         TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
668       } else if (original_grad_func != rewritten_function_name) {
669         TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
670       }
671       break;
672     }
673     case AssociatedFunctionInfo::kFunctionAttr: {
674       // Change function attr to rewritten functions.
675       NameAttrList func;
676       TF_RETURN_IF_ERROR(
677           GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
678       node->ClearAttr(associated_function.attr_name());
679       func.set_name(rewritten_function_name);
680       node->AddAttr(associated_function.attr_name(), func);
681       break;
682     }
683   }
684 
685   return Status::OK();
686 }
687 
GetOrInstantiate(const string & func_name,AttrSlice attrs,FunctionLibraryRuntime::Handle * handle)688 Status CachedFunctionHandles::GetOrInstantiate(
689     const string& func_name, AttrSlice attrs,
690     FunctionLibraryRuntime::Handle* handle) {
691   string canonicalized_name = Canonicalize(func_name, attrs);
692   auto iter = handles_.find(canonicalized_name);
693   if (iter != handles_.end()) {
694     *handle = iter->second;
695     return Status::OK();
696   }
697 
698   TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
699   handles_[canonicalized_name] = *handle;
700   return Status::OK();
701 }
702 
ReleaseAllHandles()703 Status CachedFunctionHandles::ReleaseAllHandles() {
704   Status result;
705   for (auto iter : handles_) {
706     result.Update(flr_->ReleaseHandle(iter.second));
707   }
708   handles_.clear();
709   return result;
710 }
711 
ReplaceNode(Graph * g,Node * n,const NodeDef & node_def)712 xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
713   // Create the replacement node.
714   Status s;
715   Node* new_node = g->AddNode(node_def, &s);
716   if (!s.ok()) {
717     return s;
718   }
719 
720   // Record original node's output edges and remove them first. This is to avoid
721   // multiple producers for dst nodes' input.
722   std::vector<OutEdgeInfo> out_edge_info;
723   std::vector<const Edge*> out_edges;
724   for (const Edge* edge : n->out_edges()) {
725     out_edges.push_back(edge);
726     out_edge_info.push_back(
727         {edge->dst(), edge->src_output(), edge->dst_input()});
728   }
729   for (const Edge* edge : out_edges) {
730     g->RemoveEdge(edge);
731   }
732 
733   // Add original node's input and output edges to the replacement node.
734   for (const Edge* in_edge : n->in_edges()) {
735     g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
736                in_edge->dst_input());
737   }
738   for (const OutEdgeInfo& out_edge : out_edge_info) {
739     g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
740   }
741 
742   // Remove the original node.
743   g->RemoveNode(n);
744 
745   return new_node;
746 }
747 
BuildIdentityNode(Graph * graph,const string & node_name,DataType dtype,const Node * input,absl::optional<string> requested_device)748 xla::StatusOr<Node*> BuildIdentityNode(
749     Graph* graph, const string& node_name, DataType dtype, const Node* input,
750     absl::optional<string> requested_device) {
751   // Create identity node.
752   NodeDef ndef;
753   ndef.set_name(node_name);
754   ndef.set_op("Identity");
755   if (input) {
756     ndef.add_input(input->name());
757   }
758   if (requested_device) {
759     ndef.set_device(*requested_device);
760   }
761   AddNodeAttr("T", dtype, &ndef);
762   Status s;
763   Node* id_node = graph->AddNode(ndef, &s);
764   TF_RETURN_IF_ERROR(s);
765   return id_node;
766 }
767 
PropagateConstIntoFunctionalNodes(Graph * g,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)768 Status PropagateConstIntoFunctionalNodes(
769     Graph* g, const FunctionLibraryDefinition* lookup_fld,
770     FunctionLibraryDefinition* fld) {
771   for (Node* n : g->op_nodes()) {
772     if (n->type_string() == "If") {
773       TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
774     } else if (n->type_string() == "While") {
775       TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
776     }
777   }
778   return Status::OK();
779 }
780 
781 }  // namespace tensorflow
782