1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
17
18 #include <vector>
19
20 #include "absl/algorithm/container.h"
21 #include "absl/container/flat_hash_map.h"
22 #include "absl/container/flat_hash_set.h"
23 #include "absl/memory/memory.h"
24 #include "absl/strings/str_replace.h"
25 #include "absl/strings/substitute.h"
26 #include "tensorflow/compiler/jit/defs.h"
27 #include "tensorflow/core/common_runtime/device.h"
28 #include "tensorflow/core/common_runtime/device_mgr.h"
29 #include "tensorflow/core/common_runtime/device_set.h"
30 #include "tensorflow/core/common_runtime/function.h"
31 #include "tensorflow/core/common_runtime/graph_constructor.h"
32 #include "tensorflow/core/common_runtime/lower_case_op.h"
33 #include "tensorflow/core/common_runtime/lower_functional_ops.h"
34 #include "tensorflow/core/common_runtime/lower_if_op.h"
35 #include "tensorflow/core/common_runtime/lower_while_op.h"
36 #include "tensorflow/core/common_runtime/placer.h"
37 #include "tensorflow/core/framework/attr_value_util.h"
38 #include "tensorflow/core/framework/function.h"
39 #include "tensorflow/core/framework/function.pb.h"
40 #include "tensorflow/core/framework/graph_def_util.h"
41 #include "tensorflow/core/framework/node_def.pb.h"
42 #include "tensorflow/core/framework/node_def_util.h"
43 #include "tensorflow/core/framework/op_def.pb.h"
44 #include "tensorflow/core/framework/versions.pb.h"
45 #include "tensorflow/core/graph/algorithm.h"
46 #include "tensorflow/core/graph/control_flow.h"
47 #include "tensorflow/core/graph/graph_node_util.h"
48 #include "tensorflow/core/graph/tensor_id.h"
49 #include "tensorflow/core/grappler/graph_view.h"
50 #include "tensorflow/core/grappler/grappler_item.h"
51 #include "tensorflow/core/grappler/op_types.h"
52 #include "tensorflow/core/grappler/utils.h"
53 #include "tensorflow/core/grappler/utils/functions.h"
54 #include "tensorflow/core/lib/gtl/map_util.h"
55
56 namespace tensorflow {
57 namespace grappler {
58 namespace {
59
// Name of the node attribute that holds the called function for indirect
// function calls (alias of FunctionLibraryDefinition::kFuncAttr).
constexpr const char* const kFuncAttr = FunctionLibraryDefinition::kFuncAttr;

// Do not specialize functions marked with the '_nospecialize' attribute.
constexpr const char* const kNoSpecializeAttr = "_nospecialize";

// Marks functions that were created as a result of function specialization.
constexpr const char* const kGrapplerSpecializedFuncAttr =
    "_GrapplerSpecializedFunc";
68
// There are two ways of calling a TensorFlow function:
//
// 1. Direct function call: node.op() is the name of the function.
//
// 2. Indirect function call: the function name is passed through a node
//    attribute, and special TensorFlow kernels are responsible for calling the
//    function through the FunctionLibraryRuntime. Example: PartitionedCallOp.
76
77 // Check if func_node.op() matches the name in FunctionDef signature.
IsDirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)78 bool IsDirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
79 return func_node.op() == func.signature().name();
80 }
81
82 // Check if func_node has function attribute with a function name matching
83 // FunctionDef signature.
IsIndirectFunctionCall(const FunctionDef & func,const NodeDef & func_node)84 bool IsIndirectFunctionCall(const FunctionDef& func, const NodeDef& func_node) {
85 if (!IsPartitionedCall(func_node) && !IsStatefulPartitionedCall(func_node)) {
86 return false;
87 }
88
89 auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
90 return func_attr != nullptr && func_attr->has_func() &&
91 func_attr->func().name() == func.signature().name();
92 }
93
FunctionInstantiationAttributes(const FunctionDef & func,const NodeDef & func_node)94 AttrSlice FunctionInstantiationAttributes(const FunctionDef& func,
95 const NodeDef& func_node) {
96 if (IsDirectFunctionCall(func, func_node)) {
97 return AttrSlice(func_node);
98
99 } else if (IsIndirectFunctionCall(func, func_node)) {
100 auto* func_attr = AttrSlice(func_node).Find(kFuncAttr);
101 return AttrSlice(&func_attr->func().attr());
102
103 } else {
104 LOG(WARNING) << "Can't resolve function instantiation attributes: "
105 << SummarizeNodeDef(func_node);
106 return AttrSlice();
107 }
108 }
109
110 // This is a fake device that should not be used for any op kernel execution,
111 // the only purpose of this device is to be passed as a part of DeviceSet to the
112 // Placer.
113 class FakeDevice : public Device {
114 public:
FakeDevice(Env * env,const string & device)115 FakeDevice(Env* env, const string& device) : Device(env, attr(device)) {}
FakeDevice(const string & device)116 explicit FakeDevice(const string& device) : FakeDevice(nullptr, device) {}
Sync()117 Status Sync() override { return Status::OK(); }
118
119 private:
attr(const string & device)120 static DeviceAttributes attr(const string& device) {
121 DeviceNameUtils::ParsedName parsed_name;
122 bool parsed = DeviceNameUtils::ParseFullName(device, &parsed_name);
123 DCHECK(parsed) << "Failed to parse full device name: " << device;
124
125 DeviceAttributes attr;
126 attr.set_name(device);
127 attr.set_device_type(parsed_name.type);
128 return attr;
129 }
130 };
131
// -------------------------------------------------------------------------- //
// Function specialization.
//
// FunctionDef is somewhat similar to a function template in C++: given all the
// type parameters (and attribute values), it generates a statically defined
// graph from the type-parametrized "graph template" (function body).
//
// Function specialization instantiates a parametrized FunctionDef into a
// statically defined graph, and then converts it back to a fully defined
// FunctionDef (one that doesn't have any unknown type parameters or attribute
// values, known as placeholders).
//
// Given the fully specified graph, we can apply all the Grappler optimizers to
// it (see details in MetaOptimizer). We can also push known constant inputs
// into the function body, and remove unused outputs/inputs.
147
MarkedNoSpecialize(const FunctionDef & fdef)148 bool MarkedNoSpecialize(const FunctionDef& fdef) {
149 const auto attr = AttrSlice(&fdef.attr());
150 bool nospecialize = false;
151 return TryGetNodeAttr(attr, kNoSpecializeAttr, &nospecialize) && nospecialize;
152 }
153
154 // Specialized function instantiation type parameters, body parameters, and
155 // const inputs.
156 struct FunctionSpecializationSignature {
157 // Currently we do not support functions with tensor lists as inputs or
158 // outputs, so caller node input/output ports always match function
159 // input/output arguments.
160 using InputPort = int;
161 using OutputPort = int;
162
163 string func_name;
164 bool is_in_fetch_set;
165 absl::flat_hash_set<OutputPort> active_outputs;
166 absl::flat_hash_map<string, DataType> type_parameters;
167 absl::flat_hash_map<string, AttrValue> body_parameters;
168 absl::flat_hash_map<InputPort, string> const_inputs;
169
operator ==tensorflow::grappler::__anondd39c0ba0111::FunctionSpecializationSignature170 bool operator==(const FunctionSpecializationSignature& other) const {
171 bool equals = func_name == other.func_name &&
172 is_in_fetch_set == other.is_in_fetch_set &&
173 active_outputs == other.active_outputs &&
174 type_parameters == other.type_parameters &&
175 const_inputs == other.const_inputs;
176
177 if (!equals) return false;
178
179 // Equality is not defined for AttrValue.
180 if (body_parameters.size() != other.body_parameters.size()) return false;
181
182 for (const auto& lhs : body_parameters) {
183 auto it = other.body_parameters.find(lhs.first);
184 if (it == other.body_parameters.end()) return false;
185 if (!FastAreAttrValuesEqual(lhs.second, (*it).second)) return false;
186 }
187
188 return true;
189 }
190
191 template <typename H>
AbslHashValue(H h,const FunctionSpecializationSignature & s)192 friend H AbslHashValue(H h, const FunctionSpecializationSignature& s) {
193 H base = H::combine(std::move(h), s.func_name, s.is_in_fetch_set);
194
195 // First pre-compute hashes for all values in collections with
196 // non-deterministic iteration order.
197 std::vector<uint64> hashes;
198 hashes.reserve(s.active_outputs.size() //
199 + s.type_parameters.size() * 2 //
200 + s.body_parameters.size() * 2 //
201 + s.const_inputs.size() * 2);
202
203 absl::c_transform(s.active_outputs, std::back_inserter(hashes),
204 hash<OutputPort>());
205
206 using TypeParam = std::pair<const string, DataType>;
207 absl::c_for_each(s.type_parameters, [&hashes](const TypeParam& type_param) {
208 AttrValue attr_value;
209 attr_value.set_type(type_param.second);
210 hashes.push_back(Hash64(type_param.first));
211 hashes.push_back(AttrValueHash(attr_value));
212 });
213
214 using BodyParam = std::pair<const string, AttrValue>;
215 absl::c_for_each(s.body_parameters, [&hashes](const BodyParam& body_param) {
216 hashes.push_back(Hash64(body_param.first));
217 hashes.push_back(FastAttrValueHash(body_param.second));
218 });
219
220 using ConstInput = std::pair<const InputPort, string>;
221 absl::c_for_each(s.const_inputs, [&hashes](const ConstInput& const_input) {
222 hashes.push_back(hash<InputPort>()(const_input.first));
223 hashes.push_back(Hash64(const_input.second));
224 });
225
226 // Combine all pre-computed hashes in a deterministic order.
227 absl::c_sort(hashes);
228 return H::combine_contiguous(std::move(base), hashes.data(), hashes.size());
229 }
230 };
231
// Result of specializing a function for a particular call site. It records
// what is needed to rewrite a caller node so that it calls the specialized
// function instead of the original one.
struct FunctionSpecialization {
  // Name of the specialized function.
  string specialized_func_name;
  // True if the function caller node is in GrapplerItem fetch set.
  bool is_in_fetch_set;
  // Names of the tensors that were pushed down into the function body.
  absl::flat_hash_set<string> const_inputs;
  // Control dependencies of pushed down const inputs have to be attached to
  // function caller node.
  absl::flat_hash_set<string> control_deps;
  // Output tensors (ports) that are consumed by other nodes in the graph or
  // are in a GrapplerItem fetch set.
  absl::flat_hash_set<int> active_outputs;
  // Mapping from original function output port to the output port of
  // specialized function. If function specialization changes the number of
  // function outputs it's required to update all node consumers.
  std::vector<std::pair<int, int>> output_mapping;
};
249
250 // Function optimizer context initialized once for each optimization pass, and
251 // it uses the latest available graph (for the first iteration it will be the
252 // GrapplerItem.graph, for next iterations it will be the output of previous
253 // function optimizer pass).
254 class FunctionOptimizerContext {
255 public:
FunctionOptimizerContext(const GrapplerItem & item,RewriterConfig::Toggle opt_level,const GraphDef & graph)256 explicit FunctionOptimizerContext(const GrapplerItem& item,
257 RewriterConfig::Toggle opt_level,
258 const GraphDef& graph)
259 : item_(&item),
260 opt_level_(opt_level),
261 function_library_(OpRegistry::Global(), graph.library()),
262 truly_const_nodes_(InferTrulyConstNodes(item, graph)),
263 graph_view_(&graph) {}
264
item() const265 const GrapplerItem& item() const { return *item_; }
266
graph_version() const267 const int graph_version() const { return item_->graph.versions().producer(); }
268
opt_level() const269 RewriterConfig::Toggle opt_level() const { return opt_level_; }
270
function_library() const271 const FunctionLibraryDefinition& function_library() const {
272 return function_library_;
273 }
function_library()274 FunctionLibraryDefinition& function_library() { return function_library_; }
275
276 const absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>&
tensor_mapping() const277 tensor_mapping() const {
278 return tensor_mapping_;
279 }
280
graph_view() const281 const GraphView& graph_view() const { return graph_view_; }
282
IsFeedNode(const string & node_name) const283 bool IsFeedNode(const string& node_name) const {
284 return absl::c_any_of(
285 item_->feed, [&](const std::pair<std::string, Tensor>& feed) {
286 return ParseTensorName(feed.first).node() == node_name;
287 });
288 }
289
IsFetchNode(const string & node_name) const290 bool IsFetchNode(const string& node_name) const {
291 return absl::c_any_of(item_->fetch, [&](const string& fetch) {
292 return ParseTensorName(fetch).node() == node_name;
293 });
294 }
295
IsTrulyConst(const string & name) const296 bool IsTrulyConst(const string& name) const {
297 return TrulyConstNode(name) != nullptr;
298 }
299
TrulyConstNode(const string & name) const300 const NodeDef* TrulyConstNode(const string& name) const {
301 return gtl::FindWithDefault(truly_const_nodes_, name, nullptr);
302 }
303
FindFunctionSpecialization(const FunctionSpecializationSignature & sig) const304 const FunctionSpecialization* FindFunctionSpecialization(
305 const FunctionSpecializationSignature& sig) const {
306 return gtl::FindOrNull(specialized_functions_, sig);
307 }
308
AddSpecializedFunction(const FunctionSpecializationSignature & sig,const FunctionSpecialization & specialized_func)309 void AddSpecializedFunction(const FunctionSpecializationSignature& sig,
310 const FunctionSpecialization& specialized_func) {
311 specialized_functions_.emplace(sig, specialized_func);
312 }
313
AddTensorMapping(const SafeTensorId & from,const SafeTensorId & to)314 void AddTensorMapping(const SafeTensorId& from, const SafeTensorId& to) {
315 DCHECK(from.index() != Graph::kControlSlot)
316 << "Tensor mapping must be from regular tensor";
317 DCHECK(to.index() != Graph::kControlSlot)
318 << "Tensor mapping must be to regular tensor";
319
320 auto inserted = tensor_mapping_.insert({from, to});
321 DCHECK(inserted.second)
322 << "Failed to insert duplicated tensor mapping: "
323 << "from=" << from.ToString() << " to=" << to.ToString();
324 }
325
AddTensorMapping(const string & func_node,const FunctionSpecialization & specialized_func)326 void AddTensorMapping(const string& func_node,
327 const FunctionSpecialization& specialized_func) {
328 for (const auto& pair : specialized_func.output_mapping) {
329 int from_idx = pair.first;
330 int to_idx = pair.second;
331 if (from_idx != to_idx) {
332 SafeTensorId from_tensor(func_node, from_idx);
333 SafeTensorId to_tensor(func_node, to_idx);
334 AddTensorMapping(from_tensor, to_tensor);
335 }
336 }
337 }
338
339 private:
InferTrulyConstNodes(const GrapplerItem & item,const GraphDef & graph)340 static absl::flat_hash_map<string, const NodeDef*> InferTrulyConstNodes(
341 const GrapplerItem& item, const GraphDef& graph) {
342 absl::flat_hash_set<absl::string_view> feed_nodes;
343 for (const auto& feed : item.feed) {
344 feed_nodes.insert(feed.first);
345 }
346
347 absl::flat_hash_map<string, const NodeDef*> const_nodes;
348 for (const NodeDef& node : graph.node()) {
349 if (IsConstant(node) && !feed_nodes.contains(node.name())) {
350 const_nodes[node.name()] = &node;
351 }
352 }
353
354 return const_nodes;
355 }
356
357 const GrapplerItem* item_; // must outlive this object
358 RewriterConfig::Toggle opt_level_;
359
360 // Function library constructed from current graph.
361 FunctionLibraryDefinition function_library_;
362
363 // Nodes that are Const and not in feed.
364 absl::flat_hash_map<string, const NodeDef*> truly_const_nodes_;
365 // Specialized functions.
366 absl::flat_hash_map<FunctionSpecializationSignature,
367 const FunctionSpecialization>
368 specialized_functions_;
369
370 // After function specialization, the optimized graph might be in invalid
371 // state, nodes can read from output index that is no longer valid after
372 // unused outputs pruning.
373 //
374 // Tensor mapping that has to be applied to the graph after all functions
375 // optimizations (invalidated tensor id -> optimized graph tensor id).
376 absl::flat_hash_map<SafeTensorId, SafeTensorId, SafeTensorId::Hasher>
377 tensor_mapping_;
378
379 // Use graph view to find active outputs of the function caller nodes.
380 GraphView graph_view_;
381
382 TF_DISALLOW_COPY_AND_ASSIGN(FunctionOptimizerContext);
383 };
384
385 // Returns a pointer to the called function definition iff the given node is
386 // indeed a function call. Otherwise returns nullptr.
FindFunctionCall(const FunctionOptimizerContext & ctx,const NodeDef & node)387 const FunctionDef* FindFunctionCall(const FunctionOptimizerContext& ctx,
388 const NodeDef& node) {
389 // Check if a node does indirect function call via PartitionedCallOp.
390 if (IsPartitionedCall(node) || IsStatefulPartitionedCall(node)) {
391 const AttrValue* func_attr = AttrSlice(node).Find("f");
392 return (func_attr != nullptr && func_attr->has_func())
393 ? ctx.function_library().Find(func_attr->func().name())
394 : nullptr;
395 }
396
397 // Check if the function op itself is a function name.
398 return ctx.function_library().Find(node.op());
399 }
400
GetActiveOutputs(const NodeDef & node,const FunctionOptimizerContext & ctx,int size_hint=0)401 absl::flat_hash_set<int> GetActiveOutputs(const NodeDef& node,
402 const FunctionOptimizerContext& ctx,
403 int size_hint = 0) {
404 absl::flat_hash_set<int> active_outputs;
405 active_outputs.reserve(static_cast<size_t>(size_hint));
406
407 // 1. Output can be consumed by the other graph node.
408 const auto node_fanout_edges =
409 ctx.graph_view().GetFanoutEdges(node, /*include_controlled_edges=*/false);
410 for (const GraphView::Edge& edge : node_fanout_edges) {
411 active_outputs.insert(edge.src.port_id);
412 }
413
414 // 2. Or it can be in a fetch set.
415 for (const string& fetch : ctx.item().fetch) {
416 TensorId fetch_tensor = ParseTensorName(fetch);
417 if (fetch_tensor.node() == node.name()) {
418 active_outputs.insert(fetch_tensor.index());
419 }
420 }
421
422 return active_outputs;
423 }
424
HasTrulyConstInputs(const NodeDef & node,const FunctionOptimizerContext & ctx)425 bool HasTrulyConstInputs(const NodeDef& node,
426 const FunctionOptimizerContext& ctx) {
427 const auto is_truly_const = [&ctx](const string& input) {
428 return ctx.IsTrulyConst(NodeName(input));
429 };
430 return absl::c_any_of(node.input(), is_truly_const);
431 }
432
HasUnusedOutputs(const NodeDef & func_node,const FunctionDef & func,const FunctionOptimizerContext & ctx)433 bool HasUnusedOutputs(const NodeDef& func_node, const FunctionDef& func,
434 const FunctionOptimizerContext& ctx) {
435 // Functions with tensor list outputs are not supported right now, so the
436 // number of output args is the same as number of possible function caller
437 // node outputs.
438 int num_outputs = func.signature().output_arg_size();
439 const absl::flat_hash_set<int> active_outputs =
440 GetActiveOutputs(func_node, ctx, /*size_hind*/ num_outputs);
441 int active_outputs_size = active_outputs.size();
442 return active_outputs_size != num_outputs;
443 }
444
445 // Return pruned FunctionDefLibrary with functions that are reachable from
446 // the optimized graph.
PruneFunctionLibrary(const FunctionLibraryDefinition & flib,const GraphDef & optimized_graph)447 FunctionDefLibrary PruneFunctionLibrary(const FunctionLibraryDefinition& flib,
448 const GraphDef& optimized_graph) {
449 FunctionLibraryDefinition pruned_flib =
450 flib.ReachableDefinitions(optimized_graph);
451
452 int pruned_functions = static_cast<int>(pruned_flib.num_functions()) -
453 static_cast<int>(flib.num_functions());
454
455 VLOG(3) << "Pruned function library: " << pruned_flib.num_functions()
456 << " functions (" << pruned_functions << ")";
457
458 return pruned_flib.ToProto();
459 }
460
461 // Push all constant inputs of an instantiating node into the function body.
PushDownConstInputs(const NodeDef & func_node,const FunctionOptimizerContext & ctx,GrapplerFunctionItem * item,absl::flat_hash_set<string> * const_inputs,absl::flat_hash_set<string> * control_deps)462 Status PushDownConstInputs(const NodeDef& func_node,
463 const FunctionOptimizerContext& ctx,
464 GrapplerFunctionItem* item,
465 absl::flat_hash_set<string>* const_inputs,
466 absl::flat_hash_set<string>* control_deps) {
467 // Record node control dependencies in the control_deps set.
468 const auto record_control_deps = [&](const NodeDef* const_input) {
469 for (int i = const_input->input_size() - 1; i >= 0; --i) {
470 const string& input = const_input->input(i);
471 if (IsControlInput(input))
472 control_deps->insert(input);
473 else
474 break;
475 }
476 };
477
478 for (int i = func_node.input_size() - 1; i >= 0; --i) {
479 const string& input = func_node.input(i);
480 if (IsControlInput(input)) continue;
481
482 const string node_name = NodeName(input);
483 if (ctx.IsTrulyConst(node_name)) {
484 VLOG(3) << "Push const into function body: input=" << input;
485 const auto* const_input = CHECK_NOTNULL(ctx.TrulyConstNode(node_name));
486 const_inputs->insert(input);
487 record_control_deps(const_input);
488 TF_RETURN_IF_ERROR(ReplaceInputWithConst(*const_input, i, item));
489 }
490 }
491
492 return Status::OK();
493 }
494
495 // Remove inputs that were pushed into the function body, and attach their
496 // control dependencies to the function caller node.
RemovePushedDownConstInputs(const FunctionSpecialization & specialization,NodeDef * specialized_func_node)497 void RemovePushedDownConstInputs(const FunctionSpecialization& specialization,
498 NodeDef* specialized_func_node) {
499 // Nothing to do if it was no const inputs to the function node.
500 if (specialization.const_inputs.empty()) return;
501
502 // Keep only non-const inputs.
503 std::vector<string> keep_inputs;
504 const auto& inputs = specialized_func_node->input();
505 absl::c_copy_if(inputs, std::back_inserter(keep_inputs),
506 [&](const string& input) {
507 return !specialization.const_inputs.contains(input);
508 });
509
510 specialized_func_node->clear_input();
511 for (const auto& keep : keep_inputs) specialized_func_node->add_input(keep);
512
513 // Attach control dependencies of pushed down const input to the caller node.
514 if (!specialization.control_deps.empty()) {
515 absl::flat_hash_set<string> existing_control_deps;
516
517 for (const string& input : keep_inputs) {
518 existing_control_deps.insert(AsControlDependency(NodeName(input)));
519 }
520
521 for (const string& ctrl : specialization.control_deps) {
522 if (!existing_control_deps.contains(ctrl)) {
523 VLOG(3) << "Forward control dependency: input=" << ctrl;
524 specialized_func_node->add_input(ctrl);
525 }
526 }
527 }
528 }
529
530 // Remove Tin type parameters for pushed down const inputs.
RemovePushedDownConstInputTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)531 void RemovePushedDownConstInputTypes(
532 const FunctionSpecialization& specialization, const NodeDef& func_node,
533 NodeDef* specialized_func_node) {
534 // Nothing to do if it was no const inputs to the function node.
535 if (specialization.const_inputs.empty()) return;
536
537 // Make sure that original function caller has Tin attribute.
538 const AttrValue* tin = AttrSlice(func_node).Find("Tin");
539 if (tin == nullptr || !tin->has_list()) return;
540
541 // Clear input types for the specialized node.
542 auto* attr = specialized_func_node->mutable_attr();
543 (*attr)["Tin"].mutable_list()->clear_type();
544
545 // Keep types of non-const inputs.
546 for (int i = 0; i < func_node.input_size(); ++i) {
547 const string& input = func_node.input(i);
548 if (IsControlInput(input)) break;
549
550 if (!specialization.const_inputs.contains(input)) {
551 DataType dt = tin->list().type(i);
552 (*attr)["Tin"].mutable_list()->add_type(dt);
553 }
554 }
555 }
556
557 // Remove Tout type parameters for pruned function outputs.
RemoveUnusedOutputsTypes(const FunctionSpecialization & specialization,const NodeDef & func_node,NodeDef * specialized_func_node)558 void RemoveUnusedOutputsTypes(const FunctionSpecialization& specialization,
559 const NodeDef& func_node,
560 NodeDef* specialized_func_node) {
561 // Make sure that original function caller has Tout attribute.
562 const AttrValue* tout = AttrSlice(func_node).Find("Tout");
563 if (tout == nullptr || !tout->has_list()) return;
564
565 // Nothing to do if all outputs are active.
566 int specialization_active_outputs_size = specialization.active_outputs.size();
567 if (specialization_active_outputs_size == tout->list().type_size()) return;
568
569 // Clear input types for the specialized node.
570 auto* attr = specialized_func_node->mutable_attr();
571 (*attr)["Tout"].mutable_list()->clear_type();
572
573 // Keep output types of active outputs only.
574 for (int i = 0; i < tout->list().type_size(); ++i) {
575 if (specialization.active_outputs.contains(i)) {
576 DataType dt = tout->list().type(i);
577 (*attr)["Tout"].mutable_list()->add_type(dt);
578 }
579 }
580 }
581
UpdateSpecializedFunctionCallSite(const FunctionDef & func,const NodeDef & func_node,const string & specialized_func_name,NodeDef * specialized_func_node)582 Status UpdateSpecializedFunctionCallSite(const FunctionDef& func,
583 const NodeDef& func_node,
584 const string& specialized_func_name,
585 NodeDef* specialized_func_node) {
586 if (IsDirectFunctionCall(func, func_node)) {
587 specialized_func_node->set_op(specialized_func_name);
588
589 } else if (IsIndirectFunctionCall(func, func_node)) {
590 auto* attr = specialized_func_node->mutable_attr();
591 (*attr)[kFuncAttr].mutable_func()->set_name(specialized_func_name);
592
593 } else {
594 return errors::InvalidArgument("Unknown function call site");
595 }
596
597 return Status::OK();
598 }
599
600 // Update a graph node created from the original function caller node, to the
601 // function specialization. Function specialization might change the number of
602 // inputs and outputs, so we have to make sure that graph node is updated
603 // accordingly.
UpdateSpecializedFunctionNode(const FunctionDef & func,const NodeDef & func_node,const FunctionSpecialization & specialization,NodeDef * specialized_func_node)604 Status UpdateSpecializedFunctionNode(
605 const FunctionDef& func, const NodeDef& func_node,
606 const FunctionSpecialization& specialization,
607 NodeDef* specialized_func_node) {
608 // Function called indirectly via custom kernel (e.g. PartitionedCallOp).
609 bool is_indirect_call = IsIndirectFunctionCall(func, func_node);
610
611 // 1. Call the specialized function instead of original one.
612 TF_RETURN_IF_ERROR(UpdateSpecializedFunctionCallSite(
613 func, func_node, specialization.specialized_func_name,
614 specialized_func_node));
615
616 // 2. Remove inputs corresponding to the pushed down consts.
617 RemovePushedDownConstInputs(specialization, specialized_func_node);
618
619 // NOTE: PartitionedCallOp has `Tin` and `Tout` attributes for input/output
620 // types, that must be in sync with updated function signature.
621
622 // 3. Update input types for the indirect function calls.
623 if (is_indirect_call) {
624 RemovePushedDownConstInputTypes(specialization, func_node,
625 specialized_func_node);
626 }
627
628 // 4. Update output types for the indirect function call. It's unsafe to
629 // change the number of outputs for the fetch nodes, so we just skip them.
630 if (is_indirect_call && !specialization.is_in_fetch_set) {
631 RemoveUnusedOutputsTypes(specialization, func_node, specialized_func_node);
632 }
633
634 // 5. Remove custom gradient annotation.
635 specialized_func_node->mutable_attr()->erase("_gradient_op_type");
636
637 return Status::OK();
638 }
639
InitializeFunctionSpecializationSignature(const NodeDef & func_node,const FunctionDef & func,const AttrSlice & func_instantiation_attr,const FunctionOptimizerContext & ctx,FunctionSpecializationSignature * sig)640 Status InitializeFunctionSpecializationSignature(
641 const NodeDef& func_node, const FunctionDef& func,
642 const AttrSlice& func_instantiation_attr,
643 const FunctionOptimizerContext& ctx, FunctionSpecializationSignature* sig) {
644 DCHECK(sig->const_inputs.empty());
645 DCHECK(sig->active_outputs.empty());
646
647 sig->func_name = func.signature().name();
648 sig->is_in_fetch_set = ctx.IsFetchNode(func_node.name());
649 sig->active_outputs = GetActiveOutputs(func_node, ctx);
650
651 TF_RETURN_IF_ERROR(InstantiationTypeParameters(func, func_instantiation_attr,
652 &sig->type_parameters));
653 TF_RETURN_IF_ERROR(InstantiationBodyParameters(func, func_instantiation_attr,
654 &sig->body_parameters));
655
656 for (int i = 0; i < func_node.input_size(); ++i) {
657 const string& input = func_node.input(i);
658 if (IsControlInput(input)) break;
659 if (ctx.IsTrulyConst(input)) {
660 sig->const_inputs.emplace(i, input);
661 }
662 }
663
664 return Status::OK();
665 }
666
667 // Create a name for the function specialization. The name of the function, name
668 // of the node instantiating it, and a Grappler item id should generate unique
669 // function name. Meta optimizer might create multiple Grappler items for the
670 // same graph when optimizing functions, but it's guaranteed that they all will
671 // have unique ids.
SpecializedFunctionName(const FunctionOptimizerContext & ctx,const FunctionDef & func,const NodeDef & func_node)672 string SpecializedFunctionName(const FunctionOptimizerContext& ctx,
673 const FunctionDef& func,
674 const NodeDef& func_node) {
675 return absl::Substitute(
676 "$0_specialized_for_$1_at_$2", func.signature().name(),
677 absl::StrReplaceAll(func_node.name(), {{"/", "_"}}), ctx.item().id);
678 }
679
// Specializes function `func` for the call site `func_node` and appends the
// rewritten function call node to `optimized_graph`.
//
// If `ctx` already holds a specialization with an identical signature, that
// specialization is reused. Otherwise constant inputs are pushed into the
// function body, outputs that are not active are removed (unless the call node
// is in the fetch set), and the resulting FunctionDef is added to the function
// library of `ctx` under a new name.
//
// Returns a non-OK status if any of the steps fails. Note that the new call
// node is added to `optimized_graph` before UpdateSpecializedFunctionNode runs,
// so the graph may already be modified when that call returns an error.
Status SpecializeFunction(const NodeDef& func_node, const FunctionDef& func,
                          FunctionOptimizerContext* ctx,
                          GraphDef* optimized_graph) {
  VLOG(2) << "Specialize function call: " << SummarizeNodeDef(func_node);

  const AttrSlice func_instantiation_attr =
      FunctionInstantiationAttributes(func, func_node);

  FunctionSpecializationSignature signature;
  TF_RETURN_IF_ERROR(InitializeFunctionSpecializationSignature(
      func_node, func, func_instantiation_attr, *ctx, &signature));

  // Check if function was already specialized for identical context.
  const FunctionSpecialization* already_specialized =
      ctx->FindFunctionSpecialization(signature);

  if (already_specialized) {
    VLOG(2) << "Function was already specialized in identical context: "
               "specialized_name="
            << already_specialized->specialized_func_name;

    // Add a function call node for the specialized function.
    NodeDef* specialized_func_node = optimized_graph->add_node();
    *specialized_func_node = func_node;

    TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
        func, func_node, *already_specialized, specialized_func_node));

    ctx->AddTensorMapping(specialized_func_node->name(), *already_specialized);

    return Status::OK();
  }

  // Add a new specialized function definition to the library.
  const auto& flib = ctx->function_library();

  // Make a GrapplerFunctionItem and convert it back to FunctionDef after
  // pushing all constant inputs into the function body.
  GrapplerFunctionItem item;
  TF_RETURN_IF_ERROR(MakeGrapplerFunctionItem(
      func, func_instantiation_attr, flib, ctx->graph_version(), &item));

  // Push const inputs into the function body, and keep track of their control
  // dependencies.
  absl::flat_hash_set<string> const_inputs;
  absl::flat_hash_set<string> control_deps;
  TF_RETURN_IF_ERROR(PushDownConstInputs(func_node, *ctx, &item, &const_inputs,
                                         &control_deps));

  // Remove function outputs that do not have any consumers. We can't safely
  // update outputs for the fetch nodes, so we just skip them.
  std::vector<std::pair<int, int>> output_mapping;
  if (!signature.is_in_fetch_set) {
    int num_func_outputs = item.output_size();

    // Every output index that is not listed as active is scheduled for
    // removal.
    absl::flat_hash_set<int> remove;
    for (int i = 0; i < num_func_outputs; ++i) {
      if (!signature.active_outputs.count(i)) remove.insert(i);
    }

    TF_RETURN_IF_ERROR(RemoveFunctionOutputs(remove, &item, &output_mapping));
  }

  // TODO(ezhulenev): Push down known input shapes.
  FunctionDef specialized_func;
  TF_RETURN_IF_ERROR(MakeFunctionDef(item, flib, &specialized_func));

  // Find a name for specialized function.
  const string specialized_func_name =
      SpecializedFunctionName(*ctx, func, func_node);
  if (flib.Contains(specialized_func_name)) {
    // NOTE(ezhulenev): This should never happen. If it happens, it's a sign of
    // a serious internal error, that must be investigated.
    return errors::Internal("Created duplicate function specialization");
  }

  // Rename the function and tag it with the Grappler specialization attribute.
  specialized_func.mutable_signature()->set_name(specialized_func_name);
  auto* specialized_attr = specialized_func.mutable_attr();
  (*specialized_attr)[kGrapplerSpecializedFuncAttr].set_b(true);

  // Add specialized function to the library.
  TF_RETURN_IF_ERROR(ctx->function_library().AddFunctionDef(specialized_func));

  // Add a function call node for the specialized function.
  NodeDef* specialized_func_node = optimized_graph->add_node();
  *specialized_func_node = func_node;

  FunctionSpecialization func_specialization = {
      specialized_func_name, signature.is_in_fetch_set, const_inputs,
      control_deps,          signature.active_outputs,  output_mapping};

  TF_RETURN_IF_ERROR(UpdateSpecializedFunctionNode(
      func, func_node, func_specialization, specialized_func_node));

  // Record the specialization so that identical call sites can reuse it, and
  // register the tensor mapping for the new call node.
  ctx->AddSpecializedFunction(signature, func_specialization);
  ctx->AddTensorMapping(specialized_func_node->name(), func_specialization);

  return Status::OK();
}
779
780 // -------------------------------------------------------------------------- //
781 // Inline function calls into a graph using function inlining implementation
782 // from common_runtime:
783 //
784 // 1) Convert GraphDef to Graph.
785 // 2) Inline function calls.
786 // 3) Convert Graph back to the GraphDef.
787
// Local aliases for the attribute names declared by LowerFunctionalOpsPass.
constexpr const char* const kLowerUsingSwitchMergeAttr =
    LowerFunctionalOpsPass::kLowerUsingSwitchMergeAttr;
constexpr const char* const kLowerAsMultiDeviceFunctionAttr =
    LowerFunctionalOpsPass::kLowerAsMultiDeviceFunctionAttr;

// Shorthands for the nested InlineFunctionBodyOptions types used below.
using KeepCallerNode = InlineFunctionBodyOptions::KeepCallerNode;
using OutputControlSource = InlineFunctionBodyOptions::OutputControlSource;
795
796 // Checks if boolean attribute is defined and its value is 'true'.
CheckBoolAttr(const Node * n,absl::string_view attr_name)797 bool CheckBoolAttr(const Node* n, absl::string_view attr_name) {
798 bool match;
799 bool found = TryGetNodeAttr(n->attrs(), attr_name, &match);
800 return found && match;
801 }
802
803 // Checks if string attribute is defined and it's not empty.
CheckStringAttr(const Node * n,absl::string_view attr_name)804 bool CheckStringAttr(const Node* n, absl::string_view attr_name) {
805 const string& value = GetNodeAttrString(n->attrs(), attr_name);
806 return !value.empty();
807 }
808
LowerUsingSwitchMergeIsOn(const Node * n)809 bool LowerUsingSwitchMergeIsOn(const Node* n) {
810 return CheckBoolAttr(n, kLowerUsingSwitchMergeAttr);
811 }
812
LowerAsMultiDeviceFunctionIsOn(const Node * n)813 bool LowerAsMultiDeviceFunctionIsOn(const Node* n) {
814 return CheckBoolAttr(n, kLowerAsMultiDeviceFunctionAttr);
815 }
816
MarkedForXlaCompilation(const NodeDef & n)817 bool MarkedForXlaCompilation(const NodeDef& n) {
818 auto is_enabled = [&](std::string attr_name) -> bool {
819 auto it = n.attr().find(attr_name);
820 return it != n.attr().end() && (!it->second.s().empty() || it->second.b());
821 };
822 return is_enabled("_xla_compile_id") || is_enabled("_tpu_replicate") ||
823 is_enabled(kXlaMustCompileAttr);
824 }
825
IsExemptFromSideEffectsExecutionValidation(const string & op)826 const bool IsExemptFromSideEffectsExecutionValidation(const string& op) {
827 static const auto* exemption = new absl::flat_hash_set<string>(
828 {// LINT.IfChange
829 // Op types that should not run in program order, e.g. because they need
830 // to run asynchronously to avoid deadlock.
831 "CollectiveGather", "CollectiveGatherV2", "CollectiveReduce",
832 "CollectiveReduceV2", "CollectiveBcastSend", "CollectiveBcastRecv",
833 "CollectiveBcastSendV2", "CollectiveBcastRecvV2", "NcclAllReduce",
834 "Send", "Recv",
835
836 // Legacy random ops.
837 // See details in tensorflow/python/framework/auto_control_deps.py.
838 "RandomUniform", "RandomUniformInt", "RandomStandardNormal",
839 "ParameterizedTruncatedNormal", "TruncatedNormal", "RandomShuffle",
840 "Multinomial", "RandomGamma", "RandomGammaGrad", "RandomPoisson",
841 "RandomPoissonV2",
842
843 // ReadVariableOp marked as stateful because it consumes DT_RESOURCE,
844 // but it can't generate any observable side-effect.
845 "ReadVariableOp",
846
847 // CudnnRNN ops are stateful but they can't generate any observable
848 // side-effect.
849 "CudnnRNN", "CudnnRNNBackprop", "CudnnRNNV2", "CudnnRNNV3",
850 "CudnnRNNBackpropV2", "CudnnRNNBackpropV3",
851
852 // TPUEmbedding EnqueueOps are stateful but this is only between ops with
853 // the same device_ordinal on the same host.
854 "EnqueueTPUEmbeddingSparseBatch", "EnqueueTPUEmbeddingIntegerBatch",
855 "EnqueueTPUEmbeddingSparseTensorBatch",
856 "EnqueueTPUEmbeddingRaggedTensorBatch",
857
858 // SaveV2 and RestoreV2 should be allowed to operate in parallel on
859 // multiple hosts.
860 "SaveV2", "RestoreV2"});
861 // LINT.ThenChange(//tensorflow/python/framework/auto_control_deps.py)
862 return exemption->contains(op);
863 }
864
865 // Validates that all side effects inside function body will be executed after
866 // function inlining. We do it by looking for a path from stateful ops, to one
867 // of the output control sources.
868 //
869 // When function executed via FunctionLibraryRuntime we do not have to check
870 // this, because `PruneFunctionBody` has special pruning rules for stateful ops.
ValidateSideEffectsExecution(const FunctionBody & fbody,OutputControlSource output_control_source,bool has_outgoing_control_edges,bool validate_outgoing_control_edge=true)871 Status ValidateSideEffectsExecution(
872 const FunctionBody& fbody, OutputControlSource output_control_source,
873 bool has_outgoing_control_edges,
874 bool validate_outgoing_control_edge = true) {
875 // Find all nodes that can produce side effects in the function body graph. We
876 // use 'is_stateful()' bit as an approximation of "has side effects" property.
877 std::vector<const Node*> fbody_side_effects;
878 absl::c_copy_if(
879 fbody.graph->nodes(), std::back_inserter(fbody_side_effects),
880 [](const Node* n) {
881 return n->op_def().is_stateful() && !n->IsArg() && !n->IsRetval() &&
882 !IsExemptFromSideEffectsExecutionValidation(n->type_string());
883 });
884
885 // When graph executed in TF-2.0 context with automatic control dependencies
886 // tracking, absence of outgoing control edge indicates that no one is
887 // interested in observing side effects, so it is safe to inline the function
888 // body, even if some side-effects will not be executed.
889 if (!fbody_side_effects.empty() && !has_outgoing_control_edges) {
890 const string error_message =
891 "Can't guarantee execution of function side-effects after inlining. "
892 "Function call node has no outgoing control edges.";
893 if (validate_outgoing_control_edge) {
894 return errors::Internal(error_message);
895 } else {
896 VLOG(3) << error_message;
897 }
898 }
899
900 // Find all nodes in the function body that will be used as control sources.
901 absl::flat_hash_set<const Node*> control_sources;
902 if (output_control_source == OutputControlSource::kDataOutputs) {
903 control_sources = {fbody.ret_nodes.begin(), fbody.ret_nodes.end()};
904 } else if (output_control_source == OutputControlSource::kControlOutputs) {
905 control_sources = {fbody.control_ret_nodes.begin(),
906 fbody.control_ret_nodes.end()};
907 }
908
909 for (const Node* side_effect : fbody_side_effects) {
910 VLOG(4) << "Check that node " << side_effect->name()
911 << " will execute after inlining.";
912 bool will_execute = false;
913
914 const auto is_control_source = [&](const Node* n) -> void {
915 const auto it = control_sources.find(n);
916 if (it != control_sources.end()) {
917 VLOG(4) << "Found a path to control source: " << side_effect->name()
918 << " ---> " << (*it)->name();
919 will_execute = true;
920 }
921 };
922
923 DFSFrom(*fbody.graph, {side_effect}, /*enter=*/is_control_source,
924 /*leave=*/{}, NodeComparatorName{});
925
926 if (!will_execute) {
927 return errors::Internal(
928 "Can't guarantee execution of a side-effectful node, that is not "
929 "reachable from function control source. Function body node: ",
930 SummarizeNode(*side_effect));
931 }
932 }
933
934 return Status::OK();
935 }
936
937 // Validates that no dead tensor can reach function output.
ValidateNoDeadOutputs(const FunctionLibraryDefinition & flib_def,const FunctionBody & fbody)938 Status ValidateNoDeadOutputs(const FunctionLibraryDefinition& flib_def,
939 const FunctionBody& fbody) {
940 absl::flat_hash_set<const Node*> output_nodes = {fbody.ret_nodes.begin(),
941 fbody.ret_nodes.end()};
942
943 // Find all nodes that can produce dead tensors.
944 std::vector<const Node*> dead_tensor_sources;
945 for (const Node* n : fbody.graph->nodes()) {
946 if (n->IsSwitch()) {
947 VLOG(4) << "Add dead tensors source. Switch node: " << n->name();
948 dead_tensor_sources.push_back(n);
949 continue;
950 }
951
952 // Native function call can also produce dead tensors if the function body
953 // has mergeless switches.
954 const FunctionDef* fdef = flib_def.Find(n->type_string());
955 if (fdef != nullptr) {
956 std::unique_ptr<FunctionBody> nested_fbody;
957
958 NameAttrList func;
959 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(n->def(), &func));
960 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
961 &flib_def, &nested_fbody));
962
963 if (!ValidateNoDeadOutputs(flib_def, *nested_fbody).ok()) {
964 VLOG(4) << "Add dead tensors source. Function call: " << func.name()
965 << " node=" << n->name();
966 dead_tensor_sources.push_back(n);
967 }
968 }
969 }
970
971 for (const Node* dead_tensor_source : dead_tensor_sources) {
972 bool has_dead_output = false;
973
974 const auto is_output_node = [&](const Node* n) -> void {
975 const auto it = output_nodes.find(n);
976 if (it != output_nodes.end()) {
977 VLOG(4) << "Found a path to output node from dead tensor source: "
978 << dead_tensor_source->name() << " ---> " << (*it)->name();
979 has_dead_output = true;
980 }
981 };
982
983 // Stop DFS traversal at a Merge node or if already found a dead output.
984 const auto stop_traversal = [&has_dead_output](const Edge& edge) -> bool {
985 return !edge.src()->IsMerge() || has_dead_output;
986 };
987
988 DFSFrom(*fbody.graph, {dead_tensor_source}, /*enter=*/is_output_node,
989 /*leave=*/{}, NodeComparatorName{},
990 /*edge_filter=*/stop_traversal);
991
992 if (has_dead_output) {
993 return errors::Internal(
994 "Can't inline a function with dead outputs. Dead tensor source: ",
995 SummarizeNode(*dead_tensor_source));
996 }
997 }
998
999 return Status::OK();
1000 }
1001
1002 // Makes an instance of FunctionBody for inlining from a Node.
MakeFunctionBodyForInlining(const Node & node,const FunctionLibraryDefinition & flib_def,std::unique_ptr<FunctionBody> * fbody)1003 Status MakeFunctionBodyForInlining(const Node& node,
1004 const FunctionLibraryDefinition& flib_def,
1005 std::unique_ptr<FunctionBody>* fbody) {
1006 VLOG(3) << "Make function body for inlining: " << SummarizeNode(node);
1007
1008 // Finds a FunctionDef in a library and verifies that it exists.
1009 const auto find_fdef = [&flib_def, &node](
1010 const string& name,
1011 const FunctionDef** fdef) -> Status {
1012 if ((*fdef = flib_def.Find(name)) == nullptr) {
1013 return errors::Internal(
1014 "Was not able to find a function definition (name=", name,
1015 ") for a function call: ", SummarizeNode(node));
1016 }
1017 return Status::OK();
1018 };
1019
1020 // SymbolicGradient is a special "function call" op, which has been
1021 // deprecated for a while, but we still support for compatibility reasons.
1022 if (node.type_string() == FunctionLibraryDefinition::kGradientOp) {
1023 NameAttrList func;
1024 TF_RETURN_IF_ERROR(GetNodeAttr(node.attrs(), kFuncAttr, &func));
1025
1026 const string grad = flib_def.FindGradient(func.name());
1027
1028 if (!grad.empty()) {
1029 // Function has a custom gradient registered in a library.
1030 const FunctionDef* grad_fdef;
1031 TF_RETURN_IF_ERROR(find_fdef(grad, &grad_fdef));
1032
1033 VLOG(4) << "Instantiate a custom SymbolicGradient: gradient=" << grad
1034 << " (function=" << func.name() << ")";
1035 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1036 *grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1037
1038 } else if (flib_def.Find(func.name()) == nullptr) {
1039 // Function is not really a function, but a primitive op.
1040 gradient::Creator creator;
1041 TF_RETURN_IF_ERROR(gradient::GetOpGradientCreator(func.name(), &creator));
1042 if (creator == nullptr) {
1043 return errors::InvalidArgument("No gradient is defined for ",
1044 func.name());
1045 }
1046 FunctionDef grad_fdef;
1047 TF_RETURN_IF_ERROR(creator(AttrSlice(&func.attr()), &grad_fdef));
1048
1049 VLOG(4) << "Instantiate a SymbolicGradient for a primitive op: "
1050 << func.name();
1051 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(
1052 grad_fdef, AttrSlice(&func.attr()), &flib_def, fbody));
1053
1054 } else {
1055 // Build a gradient graph from the function body.
1056 const FunctionDef* fdef;
1057 TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1058
1059 VLOG(4) << "Instantiate a SymbolicGradient for a function: "
1060 << func.name();
1061 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1062 &flib_def, fbody));
1063 *fbody = SymbolicGradient(**fbody);
1064 }
1065
1066 } else {
1067 NameAttrList func;
1068 TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node.def(), &func));
1069 const FunctionDef* fdef;
1070 TF_RETURN_IF_ERROR(find_fdef(func.name(), &fdef));
1071
1072 VLOG(4) << "Instantiate a function call: function=" << func.name();
1073 TF_RETURN_IF_ERROR(FunctionDefToBodyHelper(*fdef, AttrSlice(&func.attr()),
1074 &flib_def, fbody));
1075 }
1076
1077 return Status::OK();
1078 }
1079
1080 // Adds a control edges from each data input to the 'caller' to enforce strict
1081 // inputs semantics (all inputs are ready and alive). This is required when:
1082 //
1083 // 1) The function takes resources as inputs, and it doesn't have incoming
1084 // control edges. In Tensorflow v2 context (eager mode) this should never
1085 // happen, because automatic control dependencies tracking will add a
1086 // control edge from the last op touching the resource. However such graphs
1087 // might be produced by legacy v1 code without automatic dependency
1088 // tracking. In this case strict function call semantics is required for
1089 // enforcing side effects execution order.
1090 //
1091 // 2) One of the inputs is consuming Enter[is_constant=true] node, in which
1092 // case it will be always alive, and potentially can lead to partial
1093 // function execution after the last loop execution.
1094 //
1095 // Both of these cases would be considered illegal by construction in Tensorflow
1096 // V2, however we have to guarantee that graphs constructed with Tensorflow V1
1097 // will produce correct results.
AddStrictInputSemantics(Node * caller,Graph * g)1098 void AddStrictInputSemantics(Node* caller, Graph* g) {
1099 absl::flat_hash_set<const Node*> existing_control_sources;
1100 for (const Edge* edge : caller->in_edges()) {
1101 if (edge->IsControlEdge()) {
1102 existing_control_sources.insert(edge->src());
1103 }
1104 }
1105
1106 const bool has_incoming_control_edges = !existing_control_sources.empty();
1107
1108 const bool has_resource_input =
1109 absl::c_any_of(caller->input_types(),
1110 [](const DataType dtype) { return dtype == DT_RESOURCE; });
1111
1112 const bool has_constant_enter_input =
1113 absl::c_any_of(caller->in_edges(), [](const Edge* edge) {
1114 Node* src = edge->src();
1115 return src->IsEnter() && CheckBoolAttr(src, "is_constant");
1116 });
1117
1118 const bool requires_strict_semantics =
1119 (!has_incoming_control_edges && has_resource_input) || // Case #1
1120 (has_constant_enter_input); // Case #2
1121 if (!requires_strict_semantics) return;
1122
1123 std::set<const Node*> data_inputs;
1124 for (const Edge* edge : caller->in_edges()) {
1125 if (!edge->IsControlEdge() &&
1126 !existing_control_sources.contains(edge->src())) {
1127 data_inputs.insert(edge->src());
1128 }
1129 }
1130
1131 VLOG(3) << "Add control edges from all data inputs to enforce strict "
1132 "semantics with regard to function inputs";
1133
1134 // Do not add control edges from placeholders, because it will prevent
1135 // pruning, and they can't produce any side effects anyway.
1136 const auto is_placeholder = [](const Node* node) -> bool {
1137 return node->type_string() == "Placeholder";
1138 };
1139
1140 for (const Node* node : data_inputs) {
1141 if (is_placeholder(node)) continue;
1142 g->AddControlEdge(g->FindNodeId(node->id()), caller,
1143 /*allow_duplicates=*/true);
1144 }
1145 }
1146
1147 // Adds a control edge from a frame node if the 'caller' is executing inside a
1148 // While loop (see control_flow.h for the 'frame' node explanation).
AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo> & info,Node * caller,Graph * g)1149 void AddFrameForwardingControlEdge(const std::vector<ControlFlowInfo>& info,
1150 Node* caller, Graph* g) {
1151 // All nodes added to the graph by v2 control flow lowering and function
1152 // inlining are guaranteed to have control edges to nested function calls.
1153 int info_size = info.size();
1154 if (caller->id() >= info_size) return;
1155
1156 // Check if a lowered node is executing inside a while loop.
1157 const Node* frame = info[caller->id()].frame;
1158 const bool is_in_while_loop = frame->id() != Graph::kSourceId;
1159 if (!is_in_while_loop) return;
1160
1161 // Check if a node already has an incoming control edge. All incoming edges
1162 // must be from the same execution frame (executor.cc invariant), so if we
1163 // already have an incoming control edge, it's guaranteed that it will "carry"
1164 // the same frame as all regular inputs.
1165 const bool has_incoming_control_edges =
1166 absl::c_any_of(caller->in_edges(),
1167 [](const Edge* edge) { return edge->IsControlEdge(); });
1168 if (has_incoming_control_edges) return;
1169
1170 VLOG(3) << "Add a frame forwarding control edge: from=" << frame->name()
1171 << " to=" << caller->name();
1172 Node* enter = g->FindNodeId(frame->id());
1173 bool is_constant_enter = enter->attrs().Find("is_constant")->b();
1174 if (is_constant_enter) {
1175 // Enter[is_constant=true] is always alive. So we directly add a control
1176 // edge from that.
1177 g->AddControlEdge(enter, caller);
1178 } else {
1179 // Enter[is_constant=false] activates nodes only in 0th iteration so we
1180 // add an edge from the Merge node which is activated in every iteration.
1181 // A non-constant Enter node must have an edge to a Merge node.
1182 auto it = absl::c_find_if(enter->out_edges(), [](const Edge* e) {
1183 return !e->IsControlEdge() && e->dst()->IsMerge();
1184 });
1185 if (it != enter->out_edges().end()) {
1186 g->AddControlEdge((*it)->dst(), caller);
1187 } else {
1188 LOG(WARNING) << "Enter[is_constant=false] node: " << enter->name()
1189 << " does not have an outgoing edge to a Merge.";
1190 }
1191 }
1192 }
1193
// Inlines all function calls that are safe for inlining into the main graph.
// Also lowers control flow V2 ops (functional If/Case/While) into the V1 low
// level ops (Switch/Merge/...) when `lower_control_flow` is true.
//
// Runs a placer after inlining, to keep all nodes in a graph placed. The
// placer is skipped if nothing was inlined or if `item` has no devices.
//
// Args:
//   item: Grappler item whose graph, fetch/feed/keep sets and devices are used.
//   opt_level: RewriterConfig::AGGRESSIVE ignores the '_noinline' flag and
//     side effects validation errors.
//   lower_control_flow: enables lowering of nodes that have the
//     kLowerUsingSwitchMergeAttr attribute set.
//   output_graph: receives the resulting GraphDef.
//
// Returns a non-OK status if graph conversion, control flow analysis, lowering,
// inlining or placement fails. A function call that fails the inlining
// validations is not an error; it is simply left in the graph.
Status InlineFunctionCalls(const GrapplerItem& item,
                           const RewriterConfig::Toggle opt_level,
                           const bool lower_control_flow,
                           GraphDef* output_graph) {
  bool is_aggressive = opt_level == RewriterConfig::AGGRESSIVE;
  VLOG(2) << "Inline function calls: grappler_item_id=" << item.id
          << " (aggressive_mode=" << is_aggressive << ")";

  // Step 1: convert the GraphDef to a Graph.
  FunctionLibraryDefinition flib_def =
      FunctionLibraryDefinition(OpRegistry::Global(), item.graph.library());
  std::unique_ptr<Graph> graph = absl::make_unique<Graph>(flib_def);

  GraphConstructorOptions graph_constructor_options;
  graph_constructor_options.allow_internal_ops = true;
  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(graph_constructor_options,
                                            item.graph, graph.get()));

  // Node names are stored as string views into the strings owned by `item`.
  using NodeNames = absl::flat_hash_set<absl::string_view>;
  NodeNames fetch_nodes;
  fetch_nodes.reserve(item.fetch.size());
  for (const string& fetch : item.fetch) {
    fetch_nodes.insert(ParseTensorName(fetch).node());
  }
  NodeNames keep_nodes(item.keep_ops.begin(), item.keep_ops.end());

  // Names of the functions that were inlined; used for logging and to decide
  // whether the placer has to run.
  std::vector<string> inlined_function_names;

  // Do not inline function call nodes that are part of a feed set.
  NodeNames feed_nodes;
  feed_nodes.reserve(item.feed.size());
  for (const std::pair<std::string, Tensor>& feed : item.feed) {
    feed_nodes.insert(ParseTensorName(feed.first).node());
  }

  // If a function call is inside a While loop, it must have an incoming control
  // edge, because it will be used to pass execution frame into the function
  // body. All nodes without inputs in the function body (e.g. Const and NoOp)
  // will be added an extra control edge from the 'input_control_node'.
  std::vector<ControlFlowInfo> control_flow_info;
  TF_RETURN_IF_ERROR(BuildControlFlowInfo(graph.get(), &control_flow_info));

  // Function inlining always adds new nodes to the end of the list, so we keep
  // iterating until we are out of nodes.
  // NOTE(review): starting at id 2 presumably skips the graph's source and
  // sink nodes — confirm against Graph::kSourceId / Graph::kSinkId.
  for (int i = 2; i < graph->num_node_ids(); ++i) {
    Node* n = graph->FindNodeId(i);
    if (n == nullptr) continue;  // deleted node

    // Special case for lowering functional control flow ops. We do not rely on
    // LowerFunctionalOpsPass because in Grappler we have to be more restrictive
    // about what type of function calls we are allowed to inline.
    if (lower_control_flow && LowerUsingSwitchMergeIsOn(n)) {
      VLOG(2) << "Lower functional control flow op: " << SummarizeNode(*n);
      AddStrictInputSemantics(n, graph.get());
      AddFrameForwardingControlEdge(control_flow_info, n, graph.get());

      if (n->IsIfNode()) {
        TF_RETURN_IF_ERROR(RewriteIfNode(n, graph.get(), false));
      } else if (n->IsCaseNode()) {
        TF_RETURN_IF_ERROR(RewriteCaseNode(n, graph.get(), false));
      } else if (n->IsWhileNode()) {
        TF_RETURN_IF_ERROR(RewriteWhileNode(n, graph.get(), false));
      }
      continue;
    }

    // Skip nodes that are not function calls.
    if (!IsFunctionCall(flib_def, *n)) continue;
    // Skip function calls that we plan to compile later.
    if (MarkedForXlaCompilation(n->def())) continue;
    // Skip nodes in a feed set.
    if (feed_nodes.contains(n->name())) continue;

    // Function body that we will inline into the main graph. It can be a
    // function instantiation, or a gradient function instantiated from
    // SymbolicGradient op.
    std::unique_ptr<FunctionBody> fbody;
    TF_RETURN_IF_ERROR(MakeFunctionBodyForInlining(*n, flib_def, &fbody));

    InlineFunctionBodyOptions inline_options;
    // Ignore '_noinline' flag in aggressive mode.
    inline_options.ignore_noinline = is_aggressive;

    // Function calls created after inlining If/While ops are always inlined as
    // multi-device functions and are not required to pass additional Grappler
    // validations (side effects execution validation below).
    bool force_inline_as_multi_device = LowerAsMultiDeviceFunctionIsOn(n);

    // `PartitionedCall` is a TF-2.0 function call mechanism for multi-device
    // functions:
    //  a) Function can be multi-device.
    //  b) Automatic control dependencies tracking guarantees that all function
    //     side-effectful nodes will have a path to one of the control outputs.
    //     Control outputs and control edges between side-effectful (stateful)
    //     nodes are used to explicitly mark the nodes that must execute, and to
    //     define their execution order.
    if (n->IsPartitionedCall() || force_inline_as_multi_device) {
      inline_options.output_control_src = OutputControlSource::kControlOutputs;
      inline_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::MultiDevice();
    } else {
      inline_options.output_control_src = OutputControlSource::kDataOutputs;
      inline_options.inlined_function_body_placer =
          InlinedFunctionBodyPlacer::SingleDevice();
    }

    // Decide whether the caller node has to be kept: fetch membership is
    // checked before keep_ops membership.
    if (fetch_nodes.contains(n->name())) {
      inline_options.keep_caller_node = KeepCallerNode::kFetchable;
    } else if (keep_nodes.contains(n->name())) {
      inline_options.keep_caller_node = KeepCallerNode::kTargetable;
    } else {
      inline_options.keep_caller_node = KeepCallerNode::kDoNotKeep;
    }

    // Basic validation rules defined in common_runtime shared by all functions.
    Status can_inline_function_call =
        ValidateInlining(n, fbody.get(), inline_options);

    // Additional validation rules defined only in Grappler.
    // TODO(ezhulenev): Move it to common_runtime InlineFunctionBodyOptions?
    if (can_inline_function_call.ok()) {
      bool has_outgoing_control_edges = absl::c_any_of(
          n->out_edges(),
          [](const Edge* edge) { return edge->IsControlEdge(); });

      can_inline_function_call = ValidateSideEffectsExecution(
          *fbody, inline_options.output_control_src,
          has_outgoing_control_edges);

      // Side effects validation failures are ignored in aggressive mode and
      // for function calls forced to be inlined as multi-device functions.
      if (!can_inline_function_call.ok() &&
          (is_aggressive || force_inline_as_multi_device)) {
        VLOG(2) << "Ignore error: " << can_inline_function_call.error_message();
        can_inline_function_call = Status::OK();
      }
    }
    // Dead outputs validation is never ignored.
    if (can_inline_function_call.ok()) {
      can_inline_function_call = ValidateNoDeadOutputs(flib_def, *fbody);
    }

    if (can_inline_function_call.ok()) {
      VLOG(2) << "Inline function call node: " << n->name();
      AddStrictInputSemantics(n, graph.get());
      AddFrameForwardingControlEdge(control_flow_info, n, graph.get());

      TF_RETURN_IF_ERROR(InlineFunctionBody(flib_def, graph.get(), n,
                                            fbody.get(), inline_options));
      inlined_function_names.push_back(fbody->fdef.signature().name());

    } else {
      // Validation failure is not fatal: the call node stays in the graph.
      VLOG(2) << "Failed to inline function call node: "
              << can_inline_function_call.error_message();
    }
  }

  VLOG(4) << "Inlined " << inlined_function_names.size()
          << " function calls: " << absl::StrJoin(inlined_function_names, ", ");

  // ------------------------------------------------------------------------ //
  // Grappler receives the graph after PRE_PLACEMENT, Placer, and POST_PLACEMENT
  // passes, so each node has a valid device assignment. After function inlining
  // and control flow V2 lowering we have to keep graph placed.

  if (inlined_function_names.empty()) {
    VLOG(3) << "Not placing graph after function inlining"
            << " (did not inline any of the function calls).";

  } else if (item.devices().empty()) {
    // If there are no devices available for placer, we do not place graph after
    // function inlining. This happens when Grappler is optimizing the function
    // library, or when a graph optimized "offline", without an active runtime
    // session, for example as a part of batch job for graph
    // analysis/optimization. GrapplerItem instantiated from a function library
    // doesn't have to be fully placed after all optimizations; it will be
    // placed by the function library runtime before execution.
    VLOG(3) << "Not placing graph after function inlining"
            << " (device set is empty)";

  } else {
    // If we are running in an active runtime session, Grappler will get the
    // graph after initial placing is done, and we should have devices for the
    // placer.
    VLOG(3) << "Run placer for the graph after function inlining. "
            << "Devices: [" << absl::StrJoin(item.devices(), ", ") << "]";

    DeviceSet device_set;                               // does not own devices
    std::vector<std::unique_ptr<Device>> fake_devices;  // owns fake devices

    // Create one fake device per device name, so the placer has a device set
    // to work with.
    for (const string& name : item.devices()) {
      auto device = absl::make_unique<FakeDevice>(name);
      device_set.AddDevice(device.get());
      fake_devices.push_back(std::move(device));
    }

    Placer placer(graph.get(), item.id, &device_set);
    TF_RETURN_IF_ERROR(placer.Run());
  }

  // Step 3: convert the Graph back to the GraphDef.
  graph->ToGraphDef(output_graph);
  return Status::OK();
}
1398
1399 // Restores tensor mapping after function specialization: all inputs must be
1400 // connected to valid nodes.
RestoreTensorMapping(const FunctionOptimizerContext & ctx,GraphDef * optimized_graph)1401 void RestoreTensorMapping(const FunctionOptimizerContext& ctx,
1402 GraphDef* optimized_graph) {
1403 if (ctx.tensor_mapping().empty()) return;
1404
1405 // During function specialization, we might prune unused function outputs. We
1406 // need to "close the holes" that might appear in the function outputs.
1407 //
1408 // Example: prune unused output "f:1"
1409 //
1410 // f = my_func[T=float](...) f = my_func_specialized[T=float](...)
1411 // a = Identity(f:0) -> a = Identity(f:0)
1412 // b = Identity(f:2) b = Identity(f:1)
1413 //
1414 // Tensor mapping (size=1): [f:2 -> f:1]
1415 for (NodeDef& node : *optimized_graph->mutable_node()) {
1416 for (int idx = 0; idx < node.input_size(); ++idx) {
1417 TensorId input_tensor = ParseTensorName(node.input(idx));
1418 if (input_tensor.index() == Graph::kControlSlot) break;
1419
1420 auto mapping = ctx.tensor_mapping().find(input_tensor);
1421 if (mapping != ctx.tensor_mapping().end()) {
1422 node.set_input(idx, mapping->second.ToString());
1423 }
1424 }
1425 }
1426 }
1427
1428 } // namespace
1429
RunFunctionOptimizerPass(const GrapplerItem & item,GraphDef * optimized_graph) const1430 Status FunctionOptimizer::RunFunctionOptimizerPass(
1431 const GrapplerItem& item, GraphDef* optimized_graph) const {
1432 VLOG(3) << "Run function optimizer pass: grappler_item_id=" << item.id;
1433
1434 // Inline all function calls into a graph using common_runtime/function
1435 // implementation (see `InlineFunctionBody` function documentation).
1436 GraphDef graph_after_inlining;
1437 TF_RETURN_IF_ERROR(InlineFunctionCalls(item, opt_level_, lower_control_flow_,
1438 &graph_after_inlining));
1439
1440 // Specialize function calls that we could not inline.
1441 FunctionOptimizerContext ctx(item, opt_level_, graph_after_inlining);
1442
1443 for (const NodeDef& node : graph_after_inlining.node()) {
1444 // Function specialization can modify optimized graph only by adding new
1445 // nodes, we can check node size to make sure that graph was not modified.
1446 const int num_nodes_before = optimized_graph->node_size();
1447 const auto is_graph_modified = [&]() {
1448 int num_nodes = optimized_graph->node_size();
1449 DCHECK_GE(num_nodes, num_nodes_before) << "Nodes should not be removed";
1450 return num_nodes > num_nodes_before;
1451 };
1452
1453 // Copy node from the `graph_after_inlining` to the `optimized_graph`.
1454 const auto copy_node = [&]() { *optimized_graph->add_node() = node; };
1455
1456 // Find if a node is a function call (direct or indirect).
1457 const FunctionDef* func = FindFunctionCall(ctx, node);
1458 if (func == nullptr) {
1459 copy_node();
1460 continue;
1461 }
1462
1463 const string& func_name = func->signature().name();
1464
1465 // Specialize it to its instantiation context if it has something worth
1466 // specializing.
1467 const bool specialization_worthy = IsParametrized(*func) ||
1468 HasTrulyConstInputs(node, ctx) ||
1469 HasUnusedOutputs(node, *func, ctx);
1470
1471 // Do not specialize if function has custom gradient or marked nospecialize.
1472 const string grad_func = ctx.function_library().FindGradient(func_name);
1473 const bool no_specialize =
1474 !grad_func.empty() || ctx.IsFeedNode(node.name()) ||
1475 MarkedNoSpecialize(*func) || MarkedForXlaCompilation(node);
1476
1477 if (specialization_worthy && !no_specialize) {
1478 // TODO(ezhulenev): Specialize function call if input has a known shape.
1479 // Specialize function body for its instantiation attributes and inputs.
1480 Status status = SpecializeFunction(node, *func, &ctx, optimized_graph);
1481 if (!status.ok() && is_graph_modified()) {
1482 return status;
1483 } else if (!status.ok() && !is_graph_modified()) {
1484 VLOG(3) << "Skip specialization error: " << status.error_message();
1485 copy_node();
1486 }
1487 continue;
1488 } else {
1489 VLOG(2) << "Skip function specialization: " << func->signature().name();
1490 copy_node();
1491 }
1492 }
1493
1494 RestoreTensorMapping(ctx, optimized_graph);
1495
1496 // Preserve the graph version.
1497 *optimized_graph->mutable_versions() = item.graph.versions();
1498 // Prune unreachable function from the library.
1499 *optimized_graph->mutable_library() =
1500 PruneFunctionLibrary(ctx.function_library(), *optimized_graph);
1501
1502 return Status::OK();
1503 }
1504
Optimize(Cluster *,const GrapplerItem & item,GraphDef * optimized_graph)1505 Status FunctionOptimizer::Optimize(Cluster*, const GrapplerItem& item,
1506 GraphDef* optimized_graph) {
1507 // Nothing to do here.
1508 if (item.graph.library().function_size() == 0) {
1509 return errors::Aborted("Nothing to do.");
1510 }
1511
1512 TF_RETURN_IF_ERROR(RunFunctionOptimizerPass(item, optimized_graph));
1513
1514 return Status::OK();
1515 }
1516
Feedback(Cluster * cluster,const GrapplerItem & item,const GraphDef & optimized_graph,double result)1517 void FunctionOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
1518 const GraphDef& optimized_graph,
1519 double result) {
1520 // Nothing to do for FunctionOptimizer.
1521 }
1522
1523 } // end namespace grappler
1524 } // end namespace tensorflow
1525