1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
17
18 #include <functional>
19 #include <queue>
20 #include <random>
21 #include <set>
22 #include <unordered_map>
23
24 #include "absl/strings/str_cat.h"
25 #include "tensorflow/compiler/tf2xla/sharding_util.h"
26 #include "tensorflow/compiler/tf2xla/tf2xla.pb.h"
27 #include "tensorflow/compiler/xla/xla_data.pb.h"
28 #include "tensorflow/core/common_runtime/function.h"
29 #include "tensorflow/core/framework/graph.pb.h"
30 #include "tensorflow/core/framework/graph_def_util.h"
31 #include "tensorflow/core/framework/graph_to_functiondef.h"
32 #include "tensorflow/core/framework/node_def.pb.h"
33 #include "tensorflow/core/framework/node_def_builder.h"
34 #include "tensorflow/core/framework/node_def_util.h"
35 #include "tensorflow/core/framework/tensor_shape.h"
36 #include "tensorflow/core/framework/tensor_shape.pb.h"
37 #include "tensorflow/core/framework/versions.pb.h"
38 #include "tensorflow/core/graph/tensor_id.h"
39 #include "tensorflow/core/lib/core/errors.h"
40 #include "tensorflow/core/lib/core/status.h"
41
42 namespace tensorflow {
43
44 namespace {
45
ValidateTensorId(const tf2xla::TensorId & id)46 Status ValidateTensorId(const tf2xla::TensorId& id) {
47 if (id.node_name().empty()) {
48 return errors::InvalidArgument("TensorId node_name must be non-empty");
49 }
50 if (id.output_index() < 0) {
51 return errors::InvalidArgument("TensorId output_index must be positive");
52 }
53 return Status::OK();
54 }
55
CheckNameDuplicates(const string & kind,const string & name,std::set<string> * names)56 Status CheckNameDuplicates(const string& kind, const string& name,
57 std::set<string>* names) {
58 if (!name.empty()) {
59 if (!names->insert(name).second) {
60 return errors::InvalidArgument("duplicate ", kind, " name: ", name);
61 }
62 }
63 return Status::OK();
64 }
65
CheckFeedFetchNameConflicts(const string & kind,const std::set<string> & names)66 Status CheckFeedFetchNameConflicts(const string& kind,
67 const std::set<string>& names) {
68 // We don't allow the feeds or fetches to contain both "foo" and "foo_data",
69 // since that will cause a collision in codegen symbols.
70 for (const string& name : names) {
71 const string name_data(name + "_data");
72 if (names.find(name_data) != names.end()) {
73 return errors::InvalidArgument("conflicting ", kind, " name: ", name,
74 " and ", name_data);
75 }
76 }
77 return Status::OK();
78 }
79
// For graph `g`, copy all function call nodes' FunctionDef from `lookup_fld` to
// `fld`. This is to ensure that `fld` can instantiate FunctionDef of graph `g`.
Status CopyAssociatedFunctions(Graph* g,
                               const FunctionLibraryDefinition* lookup_fld,
                               FunctionLibraryDefinition* fld) {
  for (Node* n : g->op_nodes()) {
    for (const auto& associated_function :
         GetAssociatedFunctions(*n, lookup_fld)) {
      switch (associated_function.type()) {
        case AssociatedFunctionInfo::kFunctionCallNode: {
          // Direct function call node: the callee must be present in
          // `lookup_fld`; copy its FunctionDef into `fld`.
          const FunctionDef* fdef =
              lookup_fld->Find(associated_function.func_name());
          if (!fdef) {
            return errors::Internal(
                "Cannot find function ", associated_function.func_name(),
                " for function call node ", n->DebugString());
          }
          TF_RETURN_IF_ERROR(fld->AddFunctionDef(*fdef));
          break;
        }
        case AssociatedFunctionInfo::kSymbolicGradient:
        case AssociatedFunctionInfo::kFunctionAttr:
          // Nothing to copy for these kinds.
          break;
      }
    }
  }
  return Status::OK();
}
108
// For graph `g`, replaces _Arg nodes whose "index" attribute is in
// `const_input_index_to_node` with Const nodes. The _Arg nodes themselves are
// left in place; only their data-edge consumers are rewired to read from a
// copy of the corresponding Const node.
Status ReplaceArgUsageWithConstNode(
    Graph* g,
    const std::unordered_map<int, const Node*>& const_input_index_to_node) {
  // Collect all _Arg nodes, keyed by their "index" attr.
  std::unordered_map<int, Node*> arg_nodes;
  for (Node* n : g->op_nodes()) {
    if (n->IsArg()) {
      int index;
      TF_RETURN_IF_ERROR(GetNodeAttr(n->attrs(), "index", &index));
      arg_nodes[index] = n;
    }
  }

  for (const auto& iter : const_input_index_to_node) {
    int arg_index = iter.first;
    // Clone the Const node into `g` under a fresh (unique) name.
    NodeDef const_def = iter.second->def();
    const_def.set_name(g->NewName(const_def.name()));
    Status s;
    Node* const_node = g->AddNode(const_def, &s);
    TF_RETURN_IF_ERROR(s);

    Node* arg_node = arg_nodes[arg_index];

    // Collect all usages of the _Arg node. Node ids (not pointers) are
    // recorded because nodes are replaced while iterating below.
    struct OutEdgeInfo {
      int dst_node_id, dst_input;
    };
    std::vector<OutEdgeInfo> usages;
    for (const Edge* e : arg_node->out_edges()) {
      if (e->IsControlEdge()) {
        continue;
      }
      usages.push_back({e->dst()->id(), e->dst_input()});
    }

    for (int i = 0, end = usages.size(); i < end; i++) {
      // Make a copy of `usage_node`, and change its input to const node.
      Node* usage_node = g->FindNodeId(usages[i].dst_node_id);
      NodeDef replace_def = usage_node->def();
      *replace_def.mutable_input(usages[i].dst_input) = const_node->name();
      TF_ASSIGN_OR_RETURN(Node * replace_node,
                          ReplaceNode(g, usage_node, replace_def));
      // ReplaceNode re-adds the old input edges, so drop the edge that still
      // points at the _Arg and wire in the Const node instead.
      const Edge* usage_edge;
      TF_RETURN_IF_ERROR(
          replace_node->input_edge(usages[i].dst_input, &usage_edge));
      g->RemoveEdge(usage_edge);
      g->AddEdge(const_node, 0, replace_node, usages[i].dst_input);

      // Later entries in `usages` might have `usage_node` as dst node, but
      // `usage_node` is removed. Replace such entries with `replace_node`.
      for (int j = i + 1, end = usages.size(); j < end; j++) {
        if (usages[j].dst_node_id == usages[i].dst_node_id) {
          usages[j].dst_node_id = replace_node->id();
        }
      }
    }
  }
  return Status::OK();
}
170
171 // For a node's function attr (e.g. then/else branch for "If" nodes), rewrites
172 // the function to replace _Arg nodes in `const_input_index_to_node` with Const
173 // inputs.
PropagateConstIntoFuncAttr(Node * n,const string & attr_name,const std::unordered_map<int,const Node * > & const_input_index_to_node,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)174 Status PropagateConstIntoFuncAttr(
175 Node* n, const string& attr_name,
176 const std::unordered_map<int, const Node*>& const_input_index_to_node,
177 const FunctionLibraryDefinition* lookup_fld,
178 FunctionLibraryDefinition* fld) {
179 // Instantiate the function.
180 NameAttrList func_attr;
181 TF_RETURN_IF_ERROR(GetNodeAttr(n->def(), attr_name, &func_attr));
182 const FunctionDef* fdef = lookup_fld->Find(func_attr.name());
183 if (!fdef) {
184 return errors::Internal("Cannot find function ", func_attr.name(),
185 " for node ", n->name());
186 }
187 std::unique_ptr<FunctionBody> fbody;
188 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
189 *fdef, AttrSlice(&func_attr.attr()), lookup_fld, &fbody));
190
191 // Rewrite _Arg usages with Const node.
192 Graph* func_graph = fbody->graph;
193 TF_RETURN_IF_ERROR(
194 ReplaceArgUsageWithConstNode(func_graph, const_input_index_to_node));
195
196 // Save rewritten function.
197 FunctionDef replace_fdef;
198 string new_func_name =
199 fld->UniqueFunctionName(absl::StrCat(func_attr.name(), "_const_"));
200 TF_RETURN_IF_ERROR(
201 GraphToFunctionDef(*func_graph, new_func_name, &replace_fdef));
202 TF_RETURN_IF_ERROR(fld->AddFunctionDef(
203 replace_fdef, lookup_fld->GetStackTraces(func_attr.name())));
204
205 // Change the node to use rewritten function.
206 func_attr.set_name(new_func_name);
207 n->ClearAttr(attr_name);
208 n->AddAttr(attr_name, func_attr);
209
210 TF_RETURN_IF_ERROR(fld->AddFunctionDef(
211 replace_fdef, lookup_fld->GetStackTraces(func_attr.name())));
212
213 // Copy associated functions.
214 TF_RETURN_IF_ERROR(CopyAssociatedFunctions(func_graph, lookup_fld, fld));
215
216 return Status::OK();
217 }
218
219 // For an "If" node in graph `g`, if it has Const node inputs, rewrite its
220 // then/else branch function to replace _Arg nodes with those Const inputs.
PropagateConstIntoIfNode(Graph * g,Node * if_node,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)221 Status PropagateConstIntoIfNode(Graph* g, Node* if_node,
222 const FunctionLibraryDefinition* lookup_fld,
223 FunctionLibraryDefinition* fld) {
224 // Notice that first input for If node is predicate; other inputs are function
225 // inputs.
226 std::unordered_map<int, const Node*> const_input_index_to_node;
227 for (int i = 1; i < if_node->num_inputs(); i++) {
228 const Node* input_node;
229 TF_RETURN_IF_ERROR(if_node->input_node(i, &input_node));
230 if (input_node->type_string() == "Const") {
231 const_input_index_to_node[i - 1] = input_node;
232 }
233 }
234 if (const_input_index_to_node.empty()) {
235 return Status::OK();
236 }
237
238 // Rewrite "then_branch" and "else_branch" function, replace usage of those
239 // _Arg nodes with corresponding const node.
240 for (const auto& attr_name :
241 std::vector<string>{"then_branch", "else_branch"}) {
242 TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
243 if_node, attr_name, const_input_index_to_node, lookup_fld, fld));
244 }
245
246 return Status::OK();
247 }
248
// For a "While" node in graph `g`, if it has Const node inputs, rewrite its
// cond/body function to replace _Arg nodes with those Const inputs.
Status PropagateConstIntoWhileNode(Graph* g, Node* while_node,
                                   const FunctionLibraryDefinition* lookup_fld,
                                   FunctionLibraryDefinition* fld) {
  // For "While" node, we should only replace _Arg nodes which are loop
  // invariants. For such _Arg nodes, the return value's input will come
  // directly from the corresponding arg.
  std::unordered_map<int, const Node*> const_input_index_to_node;
  NameAttrList body_attr;
  TF_RETURN_IF_ERROR(GetNodeAttr(while_node->def(), "body", &body_attr));
  const FunctionDef* body_func = lookup_fld->Find(body_attr.name());
  if (!body_func) {
    return errors::Internal("Cannot find body function ", body_attr.name(),
                            " for While node ", while_node->name());
  }
  for (int i = 0; i < while_node->num_inputs(); i++) {
    const Node* input_node;
    TF_RETURN_IF_ERROR(while_node->input_node(i, &input_node));
    // Only Const-fed inputs are candidates.
    if (input_node->type_string() != "Const") {
      continue;
    }

    // Check if i-th retval's input comes from i-th arg directly.
    // For resource variable input of While nodes, TF2XLA convention is to place
    // them at the end of all inputs (after all data inputs), and *not* return
    // them. So number of While node inputs might be larger than number of its
    // outputs.
    if (i >= body_func->signature().output_arg_size()) {
      continue;
    }
    const OpDef_ArgDef& output_arg = body_func->signature().output_arg(i);
    auto output_arg_input = body_func->ret().find(output_arg.name());
    if (output_arg_input == body_func->ret().end()) {
      return errors::Internal("Cannot find input for output arg ",
                              output_arg.name(), " in function ",
                              body_attr.name());
    }
    const OpDef_ArgDef& input_arg = body_func->signature().input_arg(i);
    if (output_arg_input->second != input_arg.name()) {
      // Not loop invariant: the body produces something other than arg i as
      // output i.
      continue;
    }

    const_input_index_to_node[i] = input_node;
  }
  if (const_input_index_to_node.empty()) {
    return Status::OK();
  }

  // Rewrite "cond" and "body" function, replace usage of those _Arg nodes with
  // corresponding const node.
  for (const auto& attr_name : std::vector<string>{"cond", "body"}) {
    TF_RETURN_IF_ERROR(PropagateConstIntoFuncAttr(
        while_node, attr_name, const_input_index_to_node, lookup_fld, fld));
  }
  return Status::OK();
}
306
307 } // namespace
308
// Node attribute names used to tag nodes for TPU replication and XLA outside
// compilation, respectively (exported; declared in tf2xla_util.h).
const char kTpuReplicateAttrName[] = "_tpu_replicate";
const char kXlaOutsideCompilationAttrName[] = "_xla_outside_compilation";
311
ValidateConfig(const tf2xla::Config & config)312 Status ValidateConfig(const tf2xla::Config& config) {
313 std::set<string> names;
314 for (const tf2xla::Feed& feed : config.feed()) {
315 TF_RETURN_IF_ERROR(ValidateTensorId(feed.id()));
316 TF_RETURN_IF_ERROR(TensorShape::IsValidShape(feed.shape()));
317 TF_RETURN_IF_ERROR(CheckNameDuplicates("feed", feed.name(), &names));
318 }
319 TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("feed", names));
320 names.clear();
321 for (const tf2xla::Fetch& fetch : config.fetch()) {
322 TF_RETURN_IF_ERROR(ValidateTensorId(fetch.id()));
323 TF_RETURN_IF_ERROR(CheckNameDuplicates("fetch", fetch.name(), &names));
324 }
325 TF_RETURN_IF_ERROR(CheckFeedFetchNameConflicts("fetch", names));
326 if (config.fetch().empty()) {
327 return errors::InvalidArgument("fetches must be specified");
328 }
329 return Status::OK();
330 }
331
// Adds a Placeholder node to `graph_def` for every feed in `config`, records
// the feed-tensor -> placeholder-name mapping in `feed_remapping`, and rewires
// all inputs that referenced a fed tensor to read from the new placeholder
// instead. Feed dtypes are taken from the config when set, otherwise inferred
// by building the fed node against `op_registry`.
Status AddPlaceholdersForFeeds(
    const tf2xla::Config& config, const OpRegistryInterface* op_registry,
    std::unordered_map<string, string>* feed_remapping, GraphDef* graph_def) {
  struct PlaceholderInfo {
    const tf2xla::Feed* feed = nullptr;  // point to Feed in <config>.
    string placeholder_name;
    DataType data_type = DT_INVALID;
  };

  // Put each fed tensor into a map by name:port. A map is used for determinism
  // when creating placeholders (genrules want deterministic output).
  std::map<string, PlaceholderInfo> placeholder_info;
  for (int i = 0; i < config.feed_size(); ++i) {
    const tf2xla::Feed* feed = &config.feed(i);
    const string name_port = TensorIdToString(feed->id());
    PlaceholderInfo& info = placeholder_info[name_port];
    info.feed = feed;
    info.placeholder_name = absl::StrCat("aot_feed_", feed->id().output_index(),
                                         "/", feed->id().node_name());
    (*feed_remapping)[name_port] = info.placeholder_name;
  }

  // Verify node exists and determine data type.
  std::unordered_map<string, const NodeDef*> name_to_node;
  for (int i = 0; i < graph_def->node_size(); ++i) {
    name_to_node[graph_def->node(i).name()] = &graph_def->node(i);
  }
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    PlaceholderInfo& info = it->second;
    const tf2xla::TensorId& feed_id = info.feed->id();

    // Find the existing node and determine data type.
    auto node_it = name_to_node.find(feed_id.node_name());
    if (node_it == name_to_node.end()) {
      return errors::NotFound("Can't find feed node: ",
                              TensorIdToString(feed_id));
    }
    const NodeDef* existing = node_it->second;

    if (info.feed->type() != DT_INVALID) {
      // The config pins the dtype explicitly.
      info.data_type = info.feed->type();
    } else {
      // Build the node in order to infer its type.

      // Must first add default attrs as well, so do this in a copied GraphDef.
      GraphDef gd;
      *gd.mutable_versions() = graph_def->versions();
      *gd.add_node() = *existing;
      MergeDebugInfo(NodeDebugInfo(*existing), gd.mutable_node(0));
      TF_RETURN_IF_ERROR(
          AddDefaultAttrsToGraphDef(&gd, *op_registry, 0 /*node_offset*/));

      // Now build the node from the copied node def.
      Graph g(op_registry);
      g.set_versions(graph_def->versions());
      Status status;
      Node* feed_node = g.AddNode(gd.node(0), &status);
      TF_RETURN_IF_ERROR(status);

      if (info.feed->id().output_index() < feed_node->num_outputs()) {
        info.data_type =
            BaseType(feed_node->output_type(info.feed->id().output_index()));
      } else {
        return errors::InvalidArgument(
            "Invalid output_index ", info.feed->id().output_index(),
            " for feed node ", info.feed->id().node_name());
      }
    }
  }

  // Create placeholders. Note that we could avoid creating a placeholder for
  // feeds which are already placeholders, but we omit that to avoid more cases
  // in this code.
  for (auto it = placeholder_info.begin(); it != placeholder_info.end(); ++it) {
    const PlaceholderInfo& info = it->second;
    // TODO(shikharagarwal): Add original node information.
    NodeDef* d = graph_def->add_node();
    d->set_name(info.placeholder_name);
    d->set_op("Placeholder");
    auto& attr_map = *d->mutable_attr();
    attr_map["dtype"].set_type(info.data_type);
    *attr_map["shape"].mutable_shape() = info.feed->shape();
  }

  // Rewrite references to the fed tensors to refer to the placeholder.
  for (int i = 0; i < graph_def->node_size(); ++i) {
    NodeDef* node_def = graph_def->mutable_node(i);
    for (int j = 0; j < node_def->input_size(); ++j) {
      auto id = ParseTensorName(node_def->input(j));
      auto it = placeholder_info.find(id.ToString());
      if (it != placeholder_info.end()) {
        node_def->set_input(j, it->second.placeholder_name);
      }
    }
  }

  return Status::OK();
}
430
// Copies into `out` only those nodes of `in` that are reachable backwards
// from the fetch nodes in `config`, treating fed tensors as graph inputs
// (traversal does not continue past a fed edge). Node order of `in` is
// preserved. Returns InvalidArgument if a needed node is missing.
Status PruneGraphDefInto(const tf2xla::Config& config, const GraphDef& in,
                         GraphDef* out) {
  *out = in;
  out->clear_node();

  // Tensors needed for feeding.
  std::set<std::pair<string, int>> feed_tensors;
  for (const tf2xla::Feed& feed : config.feed()) {
    feed_tensors.insert(
        std::make_pair(feed.id().node_name(), feed.id().output_index()));
  }

  // Maps node name to reachability.
  std::unordered_map<string, std::pair<bool, const NodeDef*>> node_by_name;
  for (const NodeDef& node : in.node()) {
    node_by_name[node.name()] = std::pair<bool, const NodeDef*>(false, &node);
  }

  // Traverse. BFS starting from the fetch nodes, marking each visited node
  // reachable.
  std::queue<string> name_queue;
  for (int i = 0; i < config.fetch_size(); ++i) {
    name_queue.push(config.fetch(i).id().node_name());
  }
  while (!name_queue.empty()) {
    const string name = name_queue.front();
    name_queue.pop();

    auto find_it = node_by_name.find(name);
    if (find_it == node_by_name.end()) {
      return errors::InvalidArgument("While pruning graph, node ", name,
                                     " needed but not found in the graph.");
    }
    auto& map_entry = find_it->second;
    if (map_entry.first) {
      // Already visited.
      continue;
    }
    map_entry.first = true;

    // Push input nodes of the currently visited node to name_queue.
    for (const string& in_edge : map_entry.second->input()) {
      auto id = ParseTensorName(in_edge);
      const string node_name = string(id.first);
      if (feed_tensors.find(std::make_pair(node_name, id.second)) ==
          feed_tensors.end()) {
        name_queue.push(node_name);
      } else {
        // The input tensor is from an edge that is being fed. Therefore,
        // we skip recursing down that edge, to avoid requiring nodes that
        // may not be needed (note that the input node may still be added
        // to name_queue later if one of its output edges is not being fed).
      }
    }
  }

  // Copy over, preserving order of original and only nodes that are reachable
  // from the fetches.
  out->mutable_node()->Reserve(in.node_size());
  for (const NodeDef& node : in.node()) {
    if (node_by_name[node.name()].first) {
      *out->add_node() = node;
    }
  }
  return Status::OK();
}
495
TensorIdToString(const tf2xla::TensorId & id)496 string TensorIdToString(const tf2xla::TensorId& id) {
497 return absl::StrCat(id.node_name(), ":", id.output_index());
498 }
499
// Copies the assigned/requested device of one of `n`'s neighbors onto `n`.
// Scans out-edges when `out_edges` is true, otherwise in-edges, skipping
// control edges; among neighbors whose device parses to a MAXIMAL sharding,
// the one with the smallest core annotation wins. No-op if no neighbor
// matches.
Status SetNodeShardingFromNeighbors(Node* n, bool out_edges) {
  int core = -1;
  const Node* matching_node = nullptr;
  for (const Edge* edge : (out_edges ? n->out_edges() : n->in_edges())) {
    if (edge->IsControlEdge()) continue;
    const Node* possible_match = out_edges ? edge->dst() : edge->src();
    TF_ASSIGN_OR_RETURN(
        absl::optional<xla::OpSharding> sharding,
        ParseShardingFromDevice(
            *possible_match,
            /*num_cores_per_replica=*/std::numeric_limits<int32>::max(),
            /*add_metadata=*/false));
    if (sharding && sharding->type() == xla::OpSharding::MAXIMAL) {
      // NOTE(review): reads tile_assignment_devices(0) as the core of a
      // MAXIMAL sharding — confirm that field is always populated for
      // MAXIMAL shardings produced by ParseShardingFromDevice.
      const int core_annotation = sharding.value().tile_assignment_devices(0);
      if (core == -1 || core > core_annotation) {
        core = core_annotation;
        matching_node = possible_match;
      }
    }
  }
  if (matching_node != nullptr) {
    n->set_assigned_device_name(matching_node->assigned_device_name());
    n->set_requested_device(matching_node->requested_device());
  }
  return Status::OK();
}
526
AddDtypeToKernelDefConstraint(absl::string_view name,DataType dtype,KernelDef * kdef)527 void AddDtypeToKernelDefConstraint(absl::string_view name, DataType dtype,
528 KernelDef* kdef) {
529 for (KernelDef::AttrConstraint& constraint : *kdef->mutable_constraint()) {
530 if (constraint.name() == name) {
531 constraint.mutable_allowed_values()->mutable_list()->add_type(dtype);
532 }
533 }
534 }
535
536 namespace {
InitialRandomSeed()537 uint32 InitialRandomSeed() {
538 // Support plumbing the TF seed through to XLA is being worked on.
539 // If a user wants deterministic behavior, their best option
540 // is to start with a known checkpoint. This also handles issues when
541 // multiple random calls can be invoked in any order by TF executor.
542 // Another option is to use stateless random ops. They have much cleaner
543 // semantics.
544 // If a user really wants to set a deterministic seed for XLA-based
545 // devices, this is the place to do it.
546 std::random_device rd;
547 // Make the starting value odd.
548 return rd() | 1;
549 }
550 } // namespace
551
// Returns a fresh pseudo-random odd seed on each call (the |1 guarantees the
// result is never zero).
uint32 GetXLARandomSeed() {
  // We initialize counter with an odd number and increment it by two
  // everytime. This ensures that it will never be zero, even
  // after an overflow. When seeded with zero, some XLA backends
  // can return all zeros instead of random numbers.
  static std::atomic<uint32> counter(InitialRandomSeed());
  uint32 seed = counter.fetch_add(2);
  // NOTE(review): this reseeds the process-wide C PRNG and returns
  // std::rand()|1, mutating global state shared with any other std::rand()
  // users — confirm that side effect is intended.
  std::srand(seed);
  return std::rand() | 1;
}
562
563 // TODO(b/77601805): add tests for associated function related stuff.
HasAssociatedFunction(const NodeDef & node_def,const FunctionLibraryDefinition * fld)564 bool HasAssociatedFunction(const NodeDef& node_def,
565 const FunctionLibraryDefinition* fld) {
566 if (fld->Contains(node_def.op())) {
567 return true;
568 }
569
570 if (node_def.op() == FunctionLibraryDefinition::kGradientOp) {
571 // Gradient op has "f" attr, which is set to the function we are getting
572 // gradient for. We need to functionalize the gradient function.
573 return true;
574 }
575
576 if (node_def.op() == "XlaHostCompute") {
577 // XlaHostCompute has "shape_inference_graph" func attr, but that's not
578 // related to graph execution.
579 return false;
580 }
581
582 for (const auto& iter : node_def.attr()) {
583 if (iter.second.has_func()) {
584 return true;
585 }
586 }
587
588 return false;
589 }
590
GetAssociatedFunctions(const Node & node,const FunctionLibraryDefinition * fld)591 std::vector<AssociatedFunctionInfo> GetAssociatedFunctions(
592 const Node& node, const FunctionLibraryDefinition* fld) {
593 std::vector<AssociatedFunctionInfo> results;
594 const string& op = node.type_string();
595 if (fld->Contains(op)) {
596 // This is a function call node.
597 AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
598 results.emplace_back(AssociatedFunctionInfo::FunctionCall(op, attrs));
599 } else if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
600 // This is a SymbolicGradient op.
601 AttrValueMap attrs(node.attrs().begin(), node.attrs().end());
602 results.emplace_back(AssociatedFunctionInfo::SymbolicGradient(op, attrs));
603 } else if (node.type_string() == "XlaHostCompute") {
604 // XlaHostCompute has "shape_inference_graph" func attr, but that's not
605 // related to graph execution.
606 } else {
607 // Collect all function attrs for the node.
608 for (auto& iter : node.attrs()) {
609 if (iter.second.has_func()) {
610 VLOG(2) << "Found function attr for node " << node.name() << ": "
611 << iter.first << " = " << iter.second.func().name();
612 results.emplace_back(AssociatedFunctionInfo::FunctionAttr(
613 iter.second.func().name(), iter.second.func().attr(), iter.first));
614 }
615 }
616 }
617 return results;
618 }
619
// Points `node`'s association at `rewritten_function_name`. For a function
// call node this rebuilds the node to call the new function; for a
// SymbolicGradient it registers/replaces the gradient mapping in `fld`; for a
// func attr it rewrites the attr in place.
Status RewriteAssociatedFunction(
    Graph* graph, Node* node, FunctionLibraryDefinition* fld,
    const AssociatedFunctionInfo& associated_function,
    const string& rewritten_function_name) {
  switch (associated_function.type()) {
    case AssociatedFunctionInfo::kFunctionCallNode: {
      // Change this node to call the new function.
      NodeDebugInfo debug_info(*node);
      NodeDefBuilder builder(node->name(), rewritten_function_name, fld,
                             &debug_info);
      // Carry over all attrs and inputs from the old node.
      for (const auto& attr : node->attrs()) {
        builder.Attr(attr.first, attr.second);
      }
      for (int i = 0; i < node->num_inputs(); i++) {
        Node* input_node;
        TF_RETURN_IF_ERROR(node->input_node(i, &input_node));
        builder.Input(input_node->name(), i, node->input_type(i));
      }
      // Prefer the assigned device; fall back to the requested one.
      builder.Device(node->assigned_device_name().empty()
                         ? node->requested_device()
                         : node->assigned_device_name());
      NodeDef node_def;
      TF_RETURN_IF_ERROR(builder.Finalize(&node_def));
      Status s;
      Node* new_node = graph->AddNode(node_def, &s);
      TF_RETURN_IF_ERROR(s);
      // Rewire all edges from the old node to the new one, then drop the old
      // node.
      for (auto edge : node->in_edges()) {
        graph->AddEdge(edge->src(), edge->src_output(), new_node,
                       edge->dst_input());
      }
      for (auto edge : node->out_edges()) {
        graph->AddEdge(new_node, edge->src_output(), edge->dst(),
                       edge->dst_input());
      }
      graph->RemoveNode(node);
      break;
    }
    case AssociatedFunctionInfo::kSymbolicGradient: {
      // Register (or replace) rewritten_function_name as the gradient of the
      // function named by the node's "f" attr.
      NameAttrList func;
      TF_RETURN_IF_ERROR(GetNodeAttr(
          node->attrs(), FunctionLibraryDefinition::kFuncAttr, &func));
      GradientDef gradient_def;
      gradient_def.set_function_name(func.name());
      gradient_def.set_gradient_func(rewritten_function_name);
      string original_grad_func = fld->FindGradient(func.name());
      if (original_grad_func.empty()) {
        TF_RETURN_IF_ERROR(fld->AddGradientDef(gradient_def));
      } else if (original_grad_func != rewritten_function_name) {
        TF_RETURN_IF_ERROR(fld->ReplaceGradient(gradient_def));
      }
      break;
    }
    case AssociatedFunctionInfo::kFunctionAttr: {
      // Change function attr to rewritten functions.
      NameAttrList func;
      TF_RETURN_IF_ERROR(
          GetNodeAttr(node->attrs(), associated_function.attr_name(), &func));
      node->ClearAttr(associated_function.attr_name());
      func.set_name(rewritten_function_name);
      node->AddAttr(associated_function.attr_name(), func);
      break;
    }
  }

  return Status::OK();
}
686
GetOrInstantiate(const string & func_name,AttrSlice attrs,FunctionLibraryRuntime::Handle * handle)687 Status CachedFunctionHandles::GetOrInstantiate(
688 const string& func_name, AttrSlice attrs,
689 FunctionLibraryRuntime::Handle* handle) {
690 string canonicalized_name = Canonicalize(func_name, attrs);
691 auto iter = handles_.find(canonicalized_name);
692 if (iter != handles_.end()) {
693 *handle = iter->second;
694 return Status::OK();
695 }
696
697 TF_RETURN_IF_ERROR(flr_->Instantiate(func_name, attrs, handle));
698 handles_[canonicalized_name] = *handle;
699 return Status::OK();
700 }
701
ReleaseAllHandles()702 Status CachedFunctionHandles::ReleaseAllHandles() {
703 Status result;
704 for (const auto& iter : handles_) {
705 result.Update(flr_->ReleaseHandle(iter.second));
706 }
707 handles_.clear();
708 return result;
709 }
710
// Replaces node `n` in graph `g` with a new node built from `node_def`,
// transferring all of `n`'s in/out edges to the replacement, and returns the
// new node. `n` is removed from the graph.
xla::StatusOr<Node*> ReplaceNode(Graph* g, Node* n, const NodeDef& node_def) {
  // Create the replacement node.
  Status s;
  Node* new_node = g->AddNode(node_def, &s);
  if (!s.ok()) {
    return s;
  }

  // Record original node's output edges and remove them first. This is to avoid
  // multiple producers for dst nodes' input.
  std::vector<OutEdgeInfo> out_edge_info;
  std::vector<const Edge*> out_edges;
  for (const Edge* edge : n->out_edges()) {
    out_edges.push_back(edge);
    out_edge_info.push_back(
        {edge->dst(), edge->src_output(), edge->dst_input()});
  }
  for (const Edge* edge : out_edges) {
    g->RemoveEdge(edge);
  }

  // Add original node's input and output edges to the replacement node.
  for (const Edge* in_edge : n->in_edges()) {
    g->AddEdge(in_edge->src(), in_edge->src_output(), new_node,
               in_edge->dst_input());
  }
  for (const OutEdgeInfo& out_edge : out_edge_info) {
    g->AddEdge(new_node, out_edge.src_output, out_edge.dst, out_edge.dst_input);
  }

  // Remove the original node.
  g->RemoveNode(n);

  return new_node;
}
746
BuildIdentityNode(Graph * graph,const string & node_name,DataType dtype,const Node * input,absl::optional<string> requested_device)747 xla::StatusOr<Node*> BuildIdentityNode(
748 Graph* graph, const string& node_name, DataType dtype, const Node* input,
749 absl::optional<string> requested_device) {
750 // Create identity node.
751 NodeDef ndef;
752 ndef.set_name(node_name);
753 ndef.set_op("Identity");
754 if (input) {
755 ndef.add_input(input->name());
756 }
757 if (requested_device) {
758 ndef.set_device(*requested_device);
759 }
760 AddNodeAttr("T", dtype, &ndef);
761 Status s;
762 Node* id_node = graph->AddNode(ndef, &s);
763 TF_RETURN_IF_ERROR(s);
764 return id_node;
765 }
766
PropagateConstIntoFunctionalNodes(Graph * g,const FunctionLibraryDefinition * lookup_fld,FunctionLibraryDefinition * fld)767 Status PropagateConstIntoFunctionalNodes(
768 Graph* g, const FunctionLibraryDefinition* lookup_fld,
769 FunctionLibraryDefinition* fld) {
770 for (Node* n : g->op_nodes()) {
771 if (n->IsIfNode()) {
772 TF_RETURN_IF_ERROR(PropagateConstIntoIfNode(g, n, lookup_fld, fld));
773 } else if (n->IsWhileNode()) {
774 TF_RETURN_IF_ERROR(PropagateConstIntoWhileNode(g, n, lookup_fld, fld));
775 }
776 }
777 return Status::OK();
778 }
779
PruneUnreachableFunctionsFromGraph(const Graph & g,FunctionLibraryDefinition * fld)780 Status PruneUnreachableFunctionsFromGraph(const Graph& g,
781 FunctionLibraryDefinition* fld) {
782 GraphDef graph_def;
783 g.ToGraphDef(&graph_def);
784 FunctionLibraryDefinition reachable_functions =
785 fld->ReachableDefinitions(graph_def);
786 for (const string& func_name : fld->ListFunctionNames()) {
787 if (!reachable_functions.Find(func_name)) {
788 TF_RETURN_IF_ERROR(fld->RemoveFunction(func_name));
789 }
790 }
791 return Status::OK();
792 }
793
RewriteTensorListWithConstElement(Graph * g,FunctionLibraryDefinition * fld)794 Status RewriteTensorListWithConstElement(Graph* g,
795 FunctionLibraryDefinition* fld) {
796 for (Node* n : g->nodes()) {
797 if (n->type_string() != "EmptyTensorList") {
798 continue;
799 }
800
801 // Find the forward While op.
802 std::vector<const Edge*> fwd_while_edges;
803 for (const Edge* e : n->out_edges()) {
804 if (!e->IsControlEdge() && e->dst()->IsWhileNode()) {
805 fwd_while_edges.push_back(e);
806 }
807 }
808 if (fwd_while_edges.size() != 1) {
809 // No forward While op found, or multiple forward While ops.
810 continue;
811 }
812
813 // Find the backward While op.
814 Node* fwd_while = fwd_while_edges[0]->dst();
815 int fwd_while_dst_input = fwd_while_edges[0]->dst_input();
816 std::vector<const Edge*> bwd_while_edges;
817 for (const Edge* e : fwd_while->out_edges()) {
818 if (e->src_output() == fwd_while_dst_input && e->dst()->IsWhileNode()) {
819 bwd_while_edges.push_back(e);
820 }
821 }
822 if (bwd_while_edges.size() != 1) {
823 // No backward While op found, or multiple backward While ops.
824 continue;
825 }
826
827 Node* bwd_while = bwd_while_edges[0]->dst();
828 int bwd_while_dst_input = bwd_while_edges[0]->dst_input();
829
830 // Look into forward While body function and check if TensorListPushBack op
831 // has a Const input.
832 NameAttrList fwd_body_attr;
833 TF_CHECK_OK(GetNodeAttr(fwd_while->def(), "body", &fwd_body_attr));
834 const FunctionDef* fwd_body = fld->Find(fwd_body_attr.name());
835 if (!fwd_body) {
836 return errors::InvalidArgument("Cannot find function ",
837 fwd_body_attr.name(), " for While node ",
838 fwd_while->DebugString());
839 }
840 std::unique_ptr<FunctionBody> fwd_fbody;
841 TF_CHECK_OK(FunctionDefToBodyHelper(
842 *fwd_body, AttrSlice(&fwd_body_attr.attr()), fld, &fwd_fbody));
843
844 // Find the TensorListPushBack node; it's one of fwd_arg's successors.
845 Node* fwd_arg = fwd_fbody->arg_nodes[fwd_while_dst_input];
846 std::vector<Node*> tl_push_nodes;
847 for (const Edge* out_edge : fwd_arg->out_edges()) {
848 if (out_edge->dst()->type_string() == "TensorListPushBack") {
849 tl_push_nodes.push_back(out_edge->dst());
850 }
851 }
852 if (tl_push_nodes.size() != 1) {
853 // No TensorListPushBack found, or multiple TensorListPushBack.
854 continue;
855 }
856
857 // Get input for the TensorListPushBack node.
858 Node* input_node;
859 TF_CHECK_OK(tl_push_nodes[0]->input_node(1, &input_node));
860 if (input_node->type_string() != "Const") {
861 // Input for the TensorList is not Const node.
862 continue;
863 }
864
865 NodeDef const_input_nodedef = input_node->def();
866
867 // Rewrite backward While body function, replace usages of
868 // TensorListPopBack with a Const node.
869 NameAttrList bwd_body_attr;
870 TF_CHECK_OK(GetNodeAttr(bwd_while->def(), "body", &bwd_body_attr));
871 const FunctionDef* bwd_body = fld->Find(bwd_body_attr.name());
872 if (!bwd_body) {
873 return errors::InvalidArgument("Cannot find function ",
874 bwd_body_attr.name(), " for While node ",
875 bwd_while->DebugString());
876 }
877 std::unique_ptr<FunctionBody> bwd_fbody;
878 TF_CHECK_OK(FunctionDefToBodyHelper(
879 *bwd_body, AttrSlice(&bwd_body_attr.attr()), fld, &bwd_fbody));
880
881 // Find the TensorListPopBack node; it's one of bwd_arg's successors.
882 Node* bwd_arg = bwd_fbody->arg_nodes[bwd_while_dst_input];
883 std::vector<Node*> tl_pop_nodes;
884 for (const Edge* out_edge : bwd_arg->out_edges()) {
885 if (out_edge->dst()->type_string() == "TensorListPopBack") {
886 tl_pop_nodes.push_back(out_edge->dst());
887 }
888 }
889 if (tl_pop_nodes.size() != 1) {
890 // No TensorListPopBack found, or multiple TensorListPopBack.
891 continue;
892 }
893
894 // Replace TensorListPopBack usages with Const node.
895 std::vector<const Edge*> edges_to_replace;
896 for (const Edge* e : tl_pop_nodes[0]->out_edges()) {
897 if (e->src_output() == 1) {
898 edges_to_replace.push_back(e);
899 }
900 }
901 if (edges_to_replace.empty()) {
902 continue;
903 }
904 Status s;
905 const_input_nodedef.set_name(
906 bwd_fbody->graph->NewName(const_input_nodedef.name()));
907 Node* const_node = bwd_fbody->graph->AddNode(const_input_nodedef, &s);
908 TF_RETURN_IF_ERROR(s);
909 for (const Edge* e : edges_to_replace) {
910 Node* dst = e->dst();
911 int dst_input = e->dst_input();
912 bwd_fbody->graph->RemoveEdge(e);
913 bwd_fbody->graph->AddEdge(const_node, 0, dst, dst_input);
914 }
915
916 // Add rewritten backward While body function.
917 FunctionDef new_fdef;
918 string new_name = fld->UniqueFunctionName(
919 absl::StrCat(bwd_body_attr.name(), "_tl_rewrite_"));
920 TF_RETURN_IF_ERROR(
921 GraphToFunctionDef(*bwd_fbody->graph, new_name, &new_fdef));
922 TF_RETURN_IF_ERROR(fld->AddFunctionDef(new_fdef));
923
924 // Change backward While op to use the new body function.
925 bwd_body_attr.set_name(new_name);
926 bwd_while->ClearAttr("body");
927 bwd_while->AddAttr("body", bwd_body_attr);
928 }
929 return Status::OK();
930 }
931
932 } // namespace tensorflow
933