1 /* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15
16 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
17
18 #include "absl/strings/str_cat.h"
19 #include "absl/strings/str_join.h"
20 #include "absl/strings/substitute.h"
21 #include "tensorflow/core/common_runtime/function.h"
22 #include "tensorflow/core/common_runtime/graph_constructor.h"
23 #include "tensorflow/core/common_runtime/metrics.h"
24 #include "tensorflow/core/framework/dataset.h"
25 #include "tensorflow/core/framework/function.pb.h"
26 #include "tensorflow/core/framework/tensor_shape.pb.h"
27 #include "tensorflow/core/framework/tensor_util.h"
28 #include "tensorflow/core/framework/versions.pb.h"
29 #include "tensorflow/core/grappler/clusters/virtual_cluster.h"
30 #include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
31 #include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
32 #include "tensorflow/core/grappler/optimizers/auto_parallel.h"
33 #include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
34 #include "tensorflow/core/grappler/optimizers/constant_folding.h"
35 #include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
36 #include "tensorflow/core/grappler/optimizers/debug_stripper.h"
37 #include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
38 #include "tensorflow/core/grappler/optimizers/function_optimizer.h"
39 #include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
40 #include "tensorflow/core/grappler/optimizers/implementation_selector.h"
41 #include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
42 #include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
43 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
44 #include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
45 #include "tensorflow/core/grappler/optimizers/remapper.h"
46 #include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
47 #include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
48 #include "tensorflow/core/grappler/utils/canonicalizer.h"
49 #include "tensorflow/core/grappler/utils/colocation.h"
50 #include "tensorflow/core/grappler/utils/functions.h"
51 #include "tensorflow/core/grappler/utils/topological_sort.h"
52 #include "tensorflow/core/grappler/utils/tpu.h"
53 #include "tensorflow/core/grappler/verifiers/structure_verifier.h"
54 #include "tensorflow/core/lib/core/status.h"
55 #include "tensorflow/core/lib/gtl/map_util.h"
56 #include "tensorflow/core/util/dump_graph.h"
57 #include "tensorflow/core/util/ptr_util.h"
58 #include "tensorflow/core/util/xla_config_registry.h"
59
60 namespace tensorflow {
61 namespace grappler {
62
63 namespace {
64
65 constexpr int kDefaultNumberOfIterations = 2;
66 constexpr int kDefaultMinGraphNodes = 4;
67
68 int64 NumEdges(const GraphDef& graph) {
69 int64 num_edges = 0;
70 for (const auto& node : graph.node()) {
71 num_edges += node.input_size();
72 }
73 return num_edges;
74 }
75
76 string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
77 return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
78 after.node_size() - before.node_size(), "), ",
79 NumEdges(after), " edges (",
80 NumEdges(after) - NumEdges(before), ")");
81 }
82
83 int NumIterations(const RewriterConfig& cfg) {
84 return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
85 ? kDefaultNumberOfIterations
86 : cfg.meta_optimizer_iterations();
87 }
88
89 // Check if optimizer is allowed to run only once.
90 bool IsRunOnceOptimizer(const string& name) {
91 return name == "layout" || name == "memory_optimizer" ||
92 name == "loop_optimizer" || name == "auto_mixed_precision" ||
93 name == "auto_mixed_precision_mkl";
94 }
95
96 // Creates a function library stub from a real function library: copies only
97 // the signatures and attributes of all the functions defined in fdef_lib.
98 // This stub can be swapped with the real function library in a graph before
99 // passing it to an optimizer, if the optimizer doesn't instantiate functions.
100 FunctionDefLibrary GetFunctionDefLibraryStub(
101 const FunctionDefLibrary& fdef_lib) {
102 FunctionDefLibrary stub;
103 for (const FunctionDef& fn : fdef_lib.function()) {
104 FunctionDef* fn_stub = stub.mutable_function()->Add();
105 *(fn_stub->mutable_signature()) = fn.signature();
106 *(fn_stub->mutable_attr()) = fn.attr();
107 *(fn_stub->mutable_arg_attr()) = fn.arg_attr();
108 *(fn_stub->mutable_resource_arg_unique_id()) = fn.resource_arg_unique_id();
109 }
110 *stub.mutable_gradient() = fdef_lib.gradient();
111 return stub;
112 }
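
// Illustrative sketch (caller-side, not part of this file): the stub produced
// above keeps only the signature, attr, arg_attr and resource_arg_unique_id
// fields (plus the gradient mapping) and drops the node_def bodies:
//
//   FunctionDefLibrary stub = GetFunctionDefLibraryStub(fdef_lib);
//   // stub.function(0).signature() matches fdef_lib.function(0).signature(),
//   // but stub.function(0).node_def_size() == 0.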
113
114 uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
115 const uint64 kTwentyMinutesInUsec = 20 * 60 * 1000 * 1000;
116 if (cfg.meta_optimizer_timeout_ms() < 0) {
117 return 0;
118 } else {
119 return cfg.meta_optimizer_timeout_ms() == 0
120 ? Env::Default()->NowMicros() + kTwentyMinutesInUsec
121 : Env::Default()->NowMicros() +
122 cfg.meta_optimizer_timeout_ms() * 1000;
123 }
124 }
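
// Quick reference for the mapping above, assuming `cfg` is a RewriterConfig
// and times are relative to Env::Default()->NowMicros():
//
//   cfg.set_meta_optimizer_timeout_ms(-1);   // -> 0, i.e. no deadline
//   cfg.set_meta_optimizer_timeout_ms(0);    // -> now + 20 minutes (default)
//   cfg.set_meta_optimizer_timeout_ms(500);  // -> now + 500 * 1000 usec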
125
126 // A helper function to decide whether to enable the automatic mixed precision
127 // optimizer.
128 bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
129 if (opt_level == RewriterConfig::ON ||
130 opt_level == RewriterConfig::AGGRESSIVE) {
131 return true;
132 }
133 return false;
134 }
135
136 bool IsXlaGlobalJitOn(
137 const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
138 xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
139 xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
140 // Return true only if XLA JIT is ON for both single-gpu and multi-gpu
141 // graphs. This is a conservative approach that turns off the memory optimizer
142 // when we are sure that all graphs will be processed by XLA JIT.
143 bool is_on = (xla_global_jit_level.single_gpu == OptimizerOptions::ON_1 ||
144 xla_global_jit_level.single_gpu == OptimizerOptions::ON_2) &&
145 (xla_global_jit_level.general == OptimizerOptions::ON_1 ||
146 xla_global_jit_level.general == OptimizerOptions::ON_2);
147 return is_on;
148 }
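
// Worked example of the conservative check above: the result is
//   (single_gpu is ON_1 or ON_2) && (general is ON_1 or ON_2),
// so e.g. single_gpu == ON_1 with general == OFF yields false, because only
// single-GPU graphs would be compiled by XLA JIT in that configuration.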
149
150 // A helper function to decide whether to enable the memory optimizer.
151 bool MemoryOptimizerEnabled(
152 RewriterConfig::MemOptType mem_opt_type,
153 OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
154 // Disable the default memory optimizer when XLA JIT is ON as it hurts the
155 // XLA JIT performance. The (current) XLA clustering can result in loss of
156 // concurrency between kernel compute and memory copies. As such, it usually
157 // loses the concurrency needed to hide the latencies of the inserted swap-ins
158 // and swap-outs and incurs great performance overhead. Remove this check when
159 // the XLA JIT can better deal with the concurrency.
160 if (mem_opt_type == RewriterConfig::DEFAULT_MEM_OPT &&
161 IsXlaGlobalJitOn(jit_level_in_session_opts)) {
162 return false;
163 }
164
165 return mem_opt_type != RewriterConfig::NO_MEM_OPT;
166 }
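
// Summary of the decision above:
//   NO_MEM_OPT                                  -> memory optimizer disabled
//   DEFAULT_MEM_OPT while XLA global JIT is on  -> disabled (see comment above)
//   any other MemOptType                        -> enabled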
167
168 } // namespace
169
170 #define MK_OPT(NAME, VALUE) \
171 if (optimizer == NAME) return std::unique_ptr<GraphOptimizer>(VALUE)
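
// For reference, a single use such as
//   MK_OPT("pruning", new ModelPruner());
// expands to
//   if (optimizer == "pruning")
//     return std::unique_ptr<GraphOptimizer>(new ModelPruner());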
172
173 bool MetaOptimizer::LowerControlFlow() const {
174 if (config_proto_.experimental().executor_type() ==
175 "SINGLE_THREADED_EXECUTOR")
176 return false;
177
178 if (config_proto_.experimental().use_tfrt()) return false;
179
180 return true;
181 }
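
// Illustrative sketch (assumed caller-side code) of two session configurations
// for which LowerControlFlow() returns false:
//
//   ConfigProto config;
//   config.mutable_experimental()->set_executor_type(
//       "SINGLE_THREADED_EXECUTOR");
//   // or:
//   config.mutable_experimental()->set_use_tfrt(true);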
182
183 std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
184 const string& optimizer) const {
185 MK_OPT("pruning", new ModelPruner());
186 MK_OPT("function",
187 new FunctionOptimizer(cfg_.function_optimization(),
188 /*lower_control_flow=*/LowerControlFlow()));
189 MK_OPT("constfold",
190 new ConstantFolding(
191 cpu_device_,
192 cfg_.experimental_disable_compressed_tensor_optimization(),
193 !cfg_.experimental_disable_folding_quantization_emulation()));
194 MK_OPT("shape", new ShapeOptimizer());
195 MK_OPT("remap", new Remapper(cfg_.remapping()));
196 MK_OPT("layout", new GenericLayoutOptimizer(
197 /*optimization level*/ cfg_.layout_optimizer(),
198 /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
199 MK_OPT("auto_mixed_precision",
200 new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA));
201 MK_OPT("auto_mixed_precision_mkl",
202 new AutoMixedPrecision(AutoMixedPrecisionMode::MKL));
203 MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
204 MK_OPT("common_subgraph_elimination",
205 new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
206 MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
207 MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas()));
208 MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
209 MK_OPT("dependency", new DependencyOptimizer(cfg_.dependency_optimization()));
210 MK_OPT("debug_stripper", new DebugStripper());
211 MK_OPT("scoped_allocator",
212 new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
213 cfg_.scoped_allocator_opts()));
214 MK_OPT("pin_to_host",
215 new PinToHostOptimizer(cfg_.pin_to_host_optimization()));
216
217 return std::unique_ptr<GraphOptimizer>();
218 }
219
220 #undef MK_OPT
221
222 MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg)
223 : cpu_device_(cpu_device),
224 config_proto_(cfg),
225 cfg_(*config_proto_.mutable_graph_options()->mutable_rewrite_options()) {
226 DCHECK(cpu_device_ == nullptr ||
227 cpu_device_->attributes().device_type() == "CPU");
228 }
229
230 Status MetaOptimizer::InitializeOptimizers(
231 std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
232 if (cfg_.disable_meta_optimizer()) {
233 return Status::OK();
234 }
235 if (!cfg_.disable_model_pruning()) {
236 optimizers->push_back(MakeUnique<ModelPruner>());
237 }
238 if (cfg_.implementation_selector() != RewriterConfig::OFF) {
239 optimizers->push_back(MakeUnique<ImplementationSelector>());
240 }
241 if (cfg_.function_optimization() != RewriterConfig::OFF) {
242 optimizers->push_back(MakeUnique<FunctionOptimizer>(
243 cfg_.function_optimization(),
244 /*lower_control_flow=*/LowerControlFlow()));
245 }
246 if (cfg_.common_subgraph_elimination() != RewriterConfig::OFF &&
247 cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
248 optimizers->push_back(MakeUnique<CommonSubgraphElimination>(
249 cfg_.common_subgraph_elimination()));
250 }
251 if (cfg_.debug_stripper() == RewriterConfig::ON) {
252 optimizers->push_back(MakeUnique<DebugStripper>());
253 }
254 if (cfg_.constant_folding() != RewriterConfig::OFF) {
255 optimizers->push_back(MakeUnique<ConstantFolding>(
256 cfg_.constant_folding(), cpu_device_,
257 cfg_.experimental_disable_compressed_tensor_optimization(),
258 !cfg_.experimental_disable_folding_quantization_emulation()));
259 }
260 if (cfg_.shape_optimization() != RewriterConfig::OFF) {
261 optimizers->push_back(MakeUnique<ShapeOptimizer>());
262 }
263 if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())) {
264 optimizers->push_back(
265 MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CUDA));
266 }
267 if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl())) {
268 optimizers->push_back(
269 MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::MKL));
270 }
271 if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
272 optimizers->push_back(MakeUnique<PinToHostOptimizer>());
273 }
274 if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
275 optimizers->push_back(
276 MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
277 }
278 if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
279 optimizers->push_back(MakeUnique<GenericLayoutOptimizer>(
280 /*optimization level*/ cfg_.layout_optimizer(),
281 /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
282 }
283 if (cfg_.remapping() != RewriterConfig::OFF) {
284 optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
285 }
286 if (cfg_.loop_optimization() != RewriterConfig::OFF) {
287 optimizers->push_back(
288 MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
289 }
290 if (cfg_.dependency_optimization() != RewriterConfig::OFF) {
291 optimizers->push_back(
292 MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
293 }
294 auto global_jit_level =
295 config_proto_.graph_options().optimizer_options().global_jit_level();
296 if (MemoryOptimizerEnabled(cfg_.memory_optimization(), global_jit_level)) {
297 if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
298 optimizers->push_back(
299 // Use the default target node name prefix "gradients/"
300 MakeUnique<MemoryOptimizer>(cfg_.memory_optimization()));
301 } else {
302 optimizers->push_back(MakeUnique<MemoryOptimizer>(
303 cfg_.memory_optimization(),
304 cfg_.memory_optimizer_target_node_name_scope()));
305 }
306 }
307 if (cfg_.auto_parallel().enable()) {
308 optimizers->push_back(
309 MakeUnique<AutoParallel>(cfg_.auto_parallel().num_replicas()));
310 }
311 if (cfg_.scoped_allocator_optimization()) {
312 optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
313 cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
314 }
315 return InitializeCustomGraphOptimizers(std::set<string>(), optimizers);
316 }
317
318 Status MetaOptimizer::InitializeOptimizersByName(
319 std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
320 std::set<string> initialized_custom_optimizers;
321 for (const string& optimizer_name : cfg_.optimizers()) {
322 auto optimizer = MakeNewOptimizer(optimizer_name);
323 if (optimizer) {
324 VLOG(2) << "Registered default graph optimizer: " << optimizer_name;
325 optimizers->push_back(std::move(optimizer));
326 continue;
327 }
328
329 auto custom_optimizer =
330 CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);
331
332 if (custom_optimizer) {
333 VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
334 TF_RETURN_IF_ERROR(custom_optimizer->InitWithConfig(
335 config_proto_, GetCustomGraphOptimizerConfig(optimizer_name)));
336 optimizers->push_back(std::move(custom_optimizer));
337 initialized_custom_optimizers.insert(optimizer_name);
338 } else {
339 VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
340 }
341 }
342 return InitializeCustomGraphOptimizers(initialized_custom_optimizers,
343 optimizers);
344 }
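
// Hedged usage sketch (assumed caller-side code): when RewriterConfig.optimizers
// is non-empty, the passes listed there (plus any configured custom optimizers)
// are used instead of the default set, in the listed order:
//
//   RewriterConfig& rewriter =
//       *config.mutable_graph_options()->mutable_rewrite_options();
//   rewriter.add_optimizers("constfold");  // a default optimizer, by name
//   rewriter.add_optimizers("my_custom");  // hypothetical registered custom pass
//
// Names that match neither MakeNewOptimizer() nor the custom registry are
// skipped with a VLOG message.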
345
346 Status MetaOptimizer::InitializeCustomGraphOptimizers(
347 const std::set<string>& pre_initialized_optimizers,
348 std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
349 for (const auto& optimizer_config : cfg_.custom_optimizers()) {
350 if (pre_initialized_optimizers.find(optimizer_config.name()) !=
351 pre_initialized_optimizers.end()) {
352 continue;
353 }
354
355 auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
356 optimizer_config.name());
357
358 if (custom_optimizer) {
359 VLOG(2) << "Registered custom configurable graph optimizer: "
360 << optimizer_config.name();
361 TF_RETURN_IF_ERROR(
362 custom_optimizer->InitWithConfig(config_proto_, &optimizer_config));
363 optimizers->push_back(std::move(custom_optimizer));
364 } else {
365 // If there are no custom optimizers with the given name, try to initialize a
366 // default optimizer. This way, custom configurable optimizers can be
367 // mixed with default optimizers in any order.
368 auto optimizer = MakeNewOptimizer(optimizer_config.name());
369 if (optimizer) {
370 VLOG(2) << "Registered default graph optimizer: "
371 << optimizer_config.name();
372 optimizers->push_back(std::move(optimizer));
373 continue;
374 }
375 VLOG(2) << "Can't register an optimizer by name: "
376 << optimizer_config.name();
377 }
378 }
379 return Status::OK();
380 }
381
382 const RewriterConfig::CustomGraphOptimizer*
383 MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
384 for (const auto& config : cfg_.custom_optimizers()) {
385 if (config.name() == name) {
386 return &config;
387 }
388 }
389 return nullptr;
390 }
391
392 void MetaOptimizer::InitializeVerifiers(
393 std::vector<std::unique_ptr<GraphVerifier>>* inter_optimizer_verifiers,
394 std::vector<std::unique_ptr<GraphVerifier>>* post_optimization_verifiers)
395 const {
396 if (cfg_.inter_optimizer_verifier_config().structure_verifier() ==
397 VerifierConfig::ON) {
398 inter_optimizer_verifiers->push_back(MakeUnique<StructureVerifier>());
399 }
400 if (cfg_.post_optimization_verifier_config().structure_verifier() ==
401 VerifierConfig::ON) {
402 post_optimization_verifiers->push_back(MakeUnique<StructureVerifier>());
403 }
404 }
405
406 Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
407 GraphDef* optimized_graph) {
408 int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
409 : cfg_.min_graph_nodes();
410 if (item.graph.node_size() < min_graph_nodes) {
411 VLOG(3) << "Skipping optimization, graph has less than " << min_graph_nodes
412 << " nodes.";
413 *optimized_graph = item.graph;
414 return Status::OK();
415 }
416
417 const uint64 start_us = Env::Default()->NowMicros();
418
419 std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
420 if (cfg_.optimizers().empty()) {
421 TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
422 } else {
423 TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
424 }
425
426 // Initialize the configured verifiers.
427 std::vector<std::unique_ptr<GraphVerifier>> inter_optimizer_verifiers;
428 std::vector<std::unique_ptr<GraphVerifier>> post_optimization_verifiers;
429 InitializeVerifiers(&inter_optimizer_verifiers, &post_optimization_verifiers);
430 if (inter_optimizer_verifiers.empty()) {
431 VLOG(2) << "No inter optimizer verifiers have been configured";
432 } else {
433 VLOG(2) << inter_optimizer_verifiers.size()
434 << " inter optimizer verifiers have been configured";
435 }
436 if (post_optimization_verifiers.empty()) {
437 VLOG(2) << "No post optimization verifiers have been configured";
438 } else {
439 VLOG(2) << post_optimization_verifiers.size()
440 << " post optimization verifiers have been configured";
441 }
442
443 VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
444 << " num_optimizers=" << optimizers.size()
445 << ", num nodes = " << item.graph.node_size();
446
447 if (optimizers.empty()) {
448 VLOG(3) << "Skipping graph optimization, no optimizers registered";
449 *optimized_graph = item.graph;
450 return Status::OK();
451 }
452
453 // Invariant: optimized_graph contains the most recently optimized version of
454 // the graph.
455 auto original_producer = item.graph.versions().producer();
456 optimized_graph->Swap(&item.graph);
457
458 GraphOptimizationResult optimization_result(item.id);
459 GraphOptimizer* sa_optimizer = nullptr;
460
461 // Constants in the graph are normally compressed after model_pruner.
462 // Do it here if model pruner is disabled.
463 if (cfg_.disable_model_pruning()) {
464 CompressConstants(optimized_graph);
465 }
466
467 for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
468 // Don't bother optimizing further if the graph is already tiny.
469 if (optimized_graph->node_size() < min_graph_nodes) {
470 VLOG(3) << "Stopping after iteration " << iteration
471 << ", graph is tiny (#nodes = " << optimized_graph->node_size()
472 << " < " << min_graph_nodes << ")";
473 break;
474 }
475
476 VLOG(4) << "Starting optimization iteration " << iteration;
477 if (VLOG_IS_ON(4)) {
478 DumpGraphDefToFile(
479 strings::StrCat("before_MetaOptimizer_iteration_", iteration, "_",
480 reinterpret_cast<uintptr_t>(optimized_graph)),
481 *optimized_graph);
482 }
483
484 for (const auto& optimizer : optimizers) {
485 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
486 // Some optimizers can run only once.
487 if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
488 // Some must run only on the last iteration.
489 if (optimizer->name() == "scoped_allocator_optimizer") {
490 if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
491 continue;
492 }
493
494 TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &item,
495 optimized_graph, &optimization_result));
496
497 if (iteration == 0 && optimizer->name() == "model_pruner") {
498 CompressConstants(optimized_graph);
499 }
500
501 if (VLOG_IS_ON(4)) {
502 DumpGraphDefToFile(
503 strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
504 optimizer->name(), "_",
505 reinterpret_cast<uintptr_t>(optimized_graph)),
506 *optimized_graph);
507 }
508 for (const auto& verifier : inter_optimizer_verifiers) {
509 // TODO(ashwinm): Need to enforce verification_deadline.
510 TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
511 }
512 }
513 if (VLOG_IS_ON(4)) {
514 DumpGraphDefToFile(
515 strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
516 reinterpret_cast<uintptr_t>(optimized_graph)),
517 *optimized_graph);
518 }
519 // TODO(ashwinm): Need to enforce verification_deadline.
520 for (const auto& verifier : post_optimization_verifiers) {
521 TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
522 }
523 }
524
525 // ScopedAllocatorOptimizer must run last.
526 if (sa_optimizer != nullptr) {
527 TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &item,
528 optimized_graph, &optimization_result));
529 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
530 }
531
532 bool is_optimized = std::find_if(optimization_result.results.begin(),
533 optimization_result.results.end(),
534 [](const OptimizerResult& result) {
535 return result.status.ok();
536 }) != optimization_result.results.end();
537
538 // Record graph optimization result.
539 optimization_results_.push_back(optimization_result);
540
541 if (is_optimized) {
542 TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
543 ReassignColocation(optimized_graph);
544 // Make sure that the optimizers preserved the graph version.
545 DCHECK_EQ(optimized_graph->versions().producer(), original_producer);
546 }
547
548 const uint64 end_us = Env::Default()->NowMicros();
549 metrics::UpdateGrapplerPassTime("OptimizeMainGraph", end_us - start_us);
550
551 return Status::OK();
552 }
553
554 Status MetaOptimizer::RunOptimizer(
555 GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
556 GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
557 const uint64 start_us = Env::Default()->NowMicros();
558
559 // If the optimizer doesn't need a function library, we replace it with a
560 // stub before running the optimization, and put it back at the end.
561 FunctionDefLibrary optimized_graph_function_library;
562 const bool is_function_library_aware = optimizer->UsesFunctionLibrary();
563
564 // Replace function library in optimized graph with a stub.
565 if (!is_function_library_aware) {
566 VLOG(3) << "Replace function library with a stub for " << optimizer->name();
567 optimized_graph_function_library.Swap(optimized_graph->mutable_library());
568 *optimized_graph->mutable_library() =
569 GetFunctionDefLibraryStub(optimized_graph_function_library);
570 }
571
572 // This swaps the current optimized_graph into optimized item and
573 // resets optimized_graph to an empty graph.
574 optimized_graph->Swap(&optimized_item->graph);
575 *optimized_graph = GraphDef();
576 optimizer->set_deadline_usec(this->deadline_usec());
577 Status status =
578 optimizer->Optimize(cluster, *optimized_item, optimized_graph);
579 const uint64 end_us = Env::Default()->NowMicros();
580 const float duration_ms = (end_us - start_us) / 1000.0f;
581 metrics::UpdateGrapplerPassTime(optimizer->name(), end_us - start_us);
582
583 string message;
584 if (!status.ok()) {
585 optimized_graph->Swap(&optimized_item->graph);
586 if (errors::IsAborted(status)) {
587 // By convention we (ab-)use the Aborted error code to signal that the
588 // optimizer returned without performing any changes to the graph.
589 message = strings::StrCat(optimizer->name(),
590 " did nothing. time = ", duration_ms, "ms.");
591 // Swallow the non-critical error.
592 status = Status::OK();
593 } else if (errors::IsDeadlineExceeded(status)) {
594 message =
595 strings::StrCat(status.ToString(), ", time = ", duration_ms, "ms.");
596 LOG(WARNING) << optimizer->name() << " failed: " << message;
597 } else {
598 message = status.ToString();
599 LOG(ERROR) << optimizer->name() << " failed: " << message;
600 }
601 } else {
602 message = strings::StrCat(
603 PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
604 ", time = ", duration_ms, "ms.");
605 VLOG(1) << optimizer->name() << ": " << message;
606 }
607
608 // Swap function library back into the main graph.
609 if (!is_function_library_aware) {
610 optimized_graph->mutable_library()->Swap(&optimized_graph_function_library);
611 }
612
613 OptimizerResult optimizer_result{optimizer->name(), message, status};
614 optimization_result->results.push_back(optimizer_result);
615
616 if (!status.ok() && cfg_.fail_on_optimizer_errors()) return status;
617
618 return Status::OK();
619 }
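
// Hedged sketch of the Aborted convention handled above: a hypothetical
// optimizer that finds nothing to rewrite can report that without failing the
// meta-optimizer:
//
//   Status MyOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
//                                GraphDef* output) {
//     if (NothingToDo(item)) return errors::Aborted("Nothing to do.");
//     ...
//   }
//
// RunOptimizer converts such an Aborted status back to OK and keeps the
// previous version of the graph.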
620
621 // Propagates `_tf_data_function` attributes from functions to their callees.
622 void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
623 FunctionDefLibrary& fdef_lib) {
624 // Collect functions that need the attribute in this set.
625 absl::flat_hash_set<std::string> tf_data_functions;
626 std::function<void(const std::string&)> collect_tf_data_functions_dfs =
627 [&](const std::string& func_name) -> void {
628 const FunctionDef* func_def = flib.Find(func_name);
629 // Skip functions that are not reachable from the optimized graph.
630 if (func_def == nullptr) return;
631
632 // Return if we already found and added this function.
633 if (tf_data_functions.contains(func_name)) return;
634
635 // We only get here if the function is (directly or indirectly) called from
636 // a tf.data function, so add it to the set.
637 tf_data_functions.insert(func_name);
638
639 // Proceed with DFS for functions called from current function.
640 for (const NodeDef& node : func_def->node_def()) {
641 if (flib.Contains(node.op())) {
642 // This is a function call node.
643 collect_tf_data_functions_dfs(node.op());
644 }
645 // Check if there are functions in attributes.
646 for (const auto& attr : node.attr()) {
647 const AttrValue& attr_value = attr.second;
648 if (attr_value.has_func()) {
649 collect_tf_data_functions_dfs(attr_value.func().name());
650 }
651 if (attr_value.has_list()) {
652 for (const auto& func : attr_value.list().func()) {
653 collect_tf_data_functions_dfs(func.name());
654 }
655 }
656 }
657 }
658 };
659 // Perform DFS for all tf.data functions in `fdef_lib`.
660 for (const auto& func_def : fdef_lib.function()) {
661 const std::string& func_name = func_def.signature().name();
662 if (data::IsTFDataFunction(func_def))
663 collect_tf_data_functions_dfs(func_name);
664 }
665 // Set attribute for tf.data functions. We cannot do this in the DFS directly
666 // because `FunctionLibraryDefinition` does not seem to provide mutable access
667 // to a `FunctionDef`.
668 for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
669 const std::string& func_name = func_def.signature().name();
670 if (tf_data_functions.contains(func_name) &&
671 !data::IsTFDataFunction(func_def)) {
672 VLOG(2) << "Marking " << func_name << " as tf.data function";
673 (*func_def.mutable_attr())[data::kTFDataFunction].set_b(true);
674 }
675 }
676 }
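
// Worked example for the propagation above (function names are hypothetical):
// if `map_fn` carries the tf.data function attribute and its body calls
// `parse_example`, which in turn calls `cast_helper`, then after
// PropagateTFDataAttrs both callees also have
// (*attr)[data::kTFDataFunction] set to true, so the function-optimization
// loop below skips them the same way it skips `map_fn`.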
677
678 Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
679 GraphDef* optimized_graph) {
680 const uint64 start_us = Env::Default()->NowMicros();
681
682 VLOG(1) << "Starting optimization for grappler item: " << item.id;
683 optimization_results_.clear();
684
685 // Constructs a FunctionLibraryDefinition with functions that are reachable
686 // from the nodes of the graph.
687 const auto minimized_flib =
688 [](const GraphDef& graph) -> FunctionLibraryDefinition {
689 return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
690 .ReachableDefinitions(graph);
691 };
692
693 // 0. The original graph might contain a huge function library that is mostly
694 // unused. This library is copied over by each individual Grappler optimizer,
695 // which adds a huge overhead. Before starting the optimization passes we just
696 // remove all the unreachable functions.
697 // TODO(ezhulenev): Construct reachable function library definition directly
698 // from the proto without constructing temporary FunctionLibraryDefinition.
699 int old_library_size = item.graph.library().function_size();
700 *item.graph.mutable_library() = minimized_flib(item.graph).ToProto();
701 int new_library_size = item.graph.library().function_size();
702
703 VLOG(1) << absl::Substitute(
704 "Deleted $0 unreachable functions from the graph (library size = $1)",
705 old_library_size - new_library_size, new_library_size);
706
707 // Save a few small fields from item before we move it.
708 bool optimize_function_library =
709 item.optimization_options().optimize_function_library;
710 const auto producer = item.graph.versions().producer();
711
712 // 1. Optimize main graph
713 TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(item), optimized_graph));
714 VLOG(1) << "Optimized main graph.";
715 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
716
717 // 2. Optimize functions reachable from the optimized graph.
718 FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
719 using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;
720
721 // Find functions for which we might need to compute a gradient at runtime.
722 absl::flat_hash_set<string> differentiable_functions;
723
724 const auto find_differentiable_functions =
725 [&](const NodeDefs& nodes) -> void {
726 for (const NodeDef& node : nodes) {
727 if (IsSymbolicGradient(node)) {
728 const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
729 if (f_attr) differentiable_functions.insert(f_attr->func().name());
730 }
731 }
732 };
733
734 // SymbolicGradient nodes inside the main graph.
735 find_differentiable_functions(optimized_graph->node());
736 // SymbolicGradient nodes inside the function library.
737 for (const FunctionDef& function : optimized_graph->library().function()) {
738 find_differentiable_functions(function.node_def());
739 }
740
741 // Find functions that will be compiled by XLA later.
742 // We do this by looking for XlaLaunch ops that call functions, then doing a
743 // depth-first search down those functions to find transitively called
744 // functions. Grappler rewrites can potentially add nodes that are not
745 // supported by XLA, so we choose to skip such functions when we optimize
746 // the function library.
747 absl::flat_hash_set<string> xla_compiled_functions;
748 std::function<void(const string&)> find_all_functions;
749 find_all_functions = [&](const string& func) -> void {
750 // Ignore call cycles in the graph
751 if (xla_compiled_functions.contains(func)) return;
752 // Find func in the flib
753 const FunctionDef* func_def = flib.Find(func);
754 CHECK(func_def) << "not found: " << func;
755 // Mark function to be ignored by grappler
756 xla_compiled_functions.insert(func);
757 // Depth first search through the func for transitively called funcs
758 for (const NodeDef& node : func_def->node_def()) {
759 for (const auto& attr : node.attr()) {
760 const AttrValue& attr_value = attr.second;
761 if (attr_value.has_func()) {
762 find_all_functions(attr_value.func().name());
763 }
764 }
765 }
766 };
767
768 auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
769 NameAttrList function;
770 for (const NodeDef& node : nodes) {
771 // Look only for XlaLaunch nodes that call a function
772 if (!IsXlaLaunch(node)) continue;
773 if (!GetNodeAttr(node, "function", &function).ok()) continue;
774 // Find all transitively called functions
775 find_all_functions(function.name());
776 }
777 };
778
779 // XlaLaunch ops inside the main graph ...
780 find_xla_compiled_functions(optimized_graph->node());
781 // ... and inside the function library.
782 for (const FunctionDef& function : optimized_graph->library().function()) {
783 find_xla_compiled_functions(function.node_def());
784 }
785 // Propagate `_tf_data_function` attributes from functions to their callees.
786 PropagateTFDataAttrs(flib, *optimized_graph->mutable_library());
787
788 // Optimize each function only once.
789 absl::flat_hash_set<string> optimized_funcs;
790 while (optimize_function_library) {
791 optimize_function_library = false;
792
793 int function_idx = 0;
794 for (const FunctionDef& func : optimized_graph->library().function()) {
795 GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
796
797 const string& func_name = func.signature().name();
798
799 // Skip functions that are not reachable from the optimized graph.
800 if (!flib.Contains(func_name)) continue;
801 // Skip already optimized functions.
802 if (optimized_funcs.contains(func_name)) continue;
803 // Skip functions that will be compiled by XLA.
804 if (xla_compiled_functions.contains(func_name)) continue;
805
806 // Skip parametrized functions (the function type or body is defined only at
807 // function call time by caller node attributes).
808 // They should be specialized to their instantiation type parameters by
809 // the function optimizer, before we can optimize the function body.
810 if (IsParametrized(func)) continue;
811
812 // Skip tf.data functions as they are optimized by the tf.data meta optimizer
813 // and during function instantiation.
814 if (data::IsTFDataFunction(func)) continue;
815
816 VLOG(3) << "Optimize function: function=" << func_name << " ["
817 << function_idx++ << " of "
818 << optimized_graph->library().function_size() << "]";
819
820 // Function optimization might specialize nested function calls, so we
821 // have to reset the flag and do at least one more pass over the library.
822 optimize_function_library = true;
823 optimized_funcs.insert(func_name);
824
825 // Make a GrapplerItem from a FunctionDef.
826 GrapplerFunctionItem func_item;
827 TF_RETURN_IF_ERROR(
828 MakeGrapplerFunctionItem(func, flib, producer, &func_item));
829
830 // If we need to compute the gradient of optimized function at runtime, we
831 // can't perform non-differentiable rewrites.
832 func_item.optimization_options().allow_non_differentiable_rewrites =
833 !differentiable_functions.contains(func_name);
834
835 // Device set available to the function is defined only by the runtime,
836 // when we instantiate and execute the function. We can't use all devices
837 // available to the main graph, because after partitioning the function
838 // call node might execute on a remote worker.
839 if (!func_item.devices().empty()) {
840 return errors::Internal("GrapplerFunctionItem devices must be empty.");
841 }
842
843 // We are not allowed to prune certain types of ops from the graph
844 // instantiated by the function definition, because we must guarantee
845 // function execution semantics wrt side effects (see
846 // function_optimizer.cc).
847 func_item.optimization_options().allow_pruning_stateful_and_dataset_ops =
848 false;
849
850 // Optimize function body graph.
851 GraphDef optimized_func_graph;
852 if (IsTPUGraphDef(*optimized_graph)) {
853 // Skip optimizing functions if this is a TPU graph. Currently, Grappler
854 // passes do not handle TPU functions correctly in a variety of ways
855 // (Note that due to the pre-placement TPU graph rewriting passes, the
856 // TPU-related ops are encapsulated away into functions). For example,
857 // TPU graphs contain TPUReplicateMetadata node that carries relevant
858 // TPU metadata and Grappler passes could prune that away. Grappler
859 // passes could also cause issues around shape inference. Since the
860 // desired and existing behavior is to not optimize TPU functions with
861 // Grappler, this check preserves that. The only exception is the
862 // implementation selector, which is required to swap in some TPU-specific
863 // lowering code and is verified to work correctly on TPUs.
864 ImplementationSelector implementation_selector;
865
866 // The implementation selector needs access to valid function signatures
867 // and attributes, and it doesn't need the actual function bodies.
868 FunctionDefLibrary func_item_function_library;
869 func_item_function_library.Swap(func_item.graph.mutable_library());
870 *func_item.graph.mutable_library() =
871 GetFunctionDefLibraryStub(func_item_function_library);
872
873 TF_RETURN_IF_ERROR(implementation_selector.Optimize(
874 cluster, func_item, &optimized_func_graph));
875 } else {
876 GrapplerFunctionItem func_item_copy = func_item;
877 TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(func_item_copy),
878 &optimized_func_graph));
879 }
880
881 // Function body optimization might have created new specialized
882 // functions for each instantiation context. Add them to the library.
883 for (const FunctionDef& func_def :
884 optimized_func_graph.library().function()) {
885 if (flib.Find(func_def.signature().name()) == nullptr) {
886 TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
887 }
888 }
889
890 // Convert optimized graph back to FunctionDef.
891 FunctionDef optimized_func;
892 func_item.SwapFunctionBody(std::move(optimized_func_graph));
893 TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));
894
895 // Replace optimized function with a new FunctionDef.
896 TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func));
897 }
898
899 // If optimized at least one function, update the graph library.
900 if (optimize_function_library) {
901 *optimized_graph->mutable_library() = flib.ToProto();
902 }
903 }
904
905 VLOG(1) << "Optimized " << optimized_funcs.size()
906 << " functions: " << absl::StrJoin(optimized_funcs, ", ");
907 VLOG(3) << "Optimized graph =\n" << optimized_graph->DebugString();
908 if (VLOG_IS_ON(1)) {
909 DumpGraphDefToFile(
910 strings::StrCat("after_MetaOptimizer_",
911 reinterpret_cast<uintptr_t>(optimized_graph)),
912 *optimized_graph);
913 }
914
915 const uint64 end_us = Env::Default()->NowMicros();
916 metrics::UpdateGrapplerPassTime("*", end_us - start_us);
917
918 return Status::OK();
919 }
920
921 string MetaOptimizer::GetResultString() const {
922 std::string result_string;
923 for (const GraphOptimizationResult& graph_result : optimization_results_) {
924 absl::StrAppend(&result_string,
925 "Optimization results for grappler item: ", graph_result.id,
926 "\n");
927 for (const OptimizerResult& result : graph_result.results) {
928 absl::StrAppend(&result_string, " ", result.optimizer_name, ": ",
929 result.message, "\n");
930 }
931 }
932 return result_string;
933 }
934
935 void MetaOptimizer::PrintResult() { LOG(INFO) << GetResultString(); }
936
937 bool MetaOptimizerEnabled(const ConfigProto& cfg) {
938 const auto& rewrite_cfg = cfg.graph_options().rewrite_options();
939 if (rewrite_cfg.disable_meta_optimizer()) {
940 return false;
941 }
942 return !rewrite_cfg.disable_model_pruning() ||
943 rewrite_cfg.layout_optimizer() != RewriterConfig::OFF ||
944 rewrite_cfg.function_optimization() != RewriterConfig::OFF ||
945 rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
946 rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
947 rewrite_cfg.remapping() != RewriterConfig::OFF ||
948 rewrite_cfg.common_subgraph_elimination() != RewriterConfig::OFF ||
949 rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
950 rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
951 rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
952 rewrite_cfg.auto_parallel().enable() ||
953 rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
954 rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
955 rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
956 rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
957 AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
958 AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||
959 !rewrite_cfg.optimizers().empty() ||
960 !rewrite_cfg.custom_optimizers().empty();
961 }
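
// Illustrative example (assumed caller-side code): with a default ConfigProto,
// MetaOptimizerEnabled() returns true, because model pruning is not disabled
// by default; turning off the meta optimizer explicitly overrides everything:
//
//   ConfigProto config;
//   config.mutable_graph_options()
//       ->mutable_rewrite_options()
//       ->set_disable_meta_optimizer(true);
//   // MetaOptimizerEnabled(config) is now false.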
962
963 Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
964 DeviceBase* cpu_device, Cluster* cluster,
965 GraphDef* optimized_graph) {
966 MetaOptimizer optimizer(cpu_device, cfg);
967 optimizer.set_deadline_usec(
968 DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
969 return optimizer.OptimizeConsumeItem(cluster, std::move(item),
970 optimized_graph);
971 }
972
973 Status OptimizeGraph(
974 std::vector<string> ret_node_names, std::vector<string> keep_node_names,
975 FunctionLibraryDefinition* flib, const DeviceSet& device_set,
976 Device* cpu_device, const ConfigProto& config_proto,
977 const string& grappler_item_id,
978 const GrapplerItem::OptimizationOptions& optimization_options,
979 std::unique_ptr<tensorflow::Graph>* g) {
980 if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) {
981 return Status::OK();
982 }
983
984 tensorflow::grappler::GrapplerItem item;
985 item.id = grappler_item_id;
986 item.optimization_options() = optimization_options;
987
988 // Add all available devices so that inlined functions can be placed.
989 for (const Device* d : device_set.devices()) {
990 Status added_device = item.AddDevice(d->name());
991 if (!added_device.ok()) VLOG(3) << added_device.error_message();
992 }
993 VLOG(3) << "Grappler available devices: "
994 << absl::StrJoin(item.devices(), ", ");
995
996 // Add fetches so that the graph can be pruned.
997 item.fetch.swap(ret_node_names);
998
999 // Add nodes that can't be removed from the graph.
1000 item.keep_ops = std::move(keep_node_names);
1001
1002 (*g)->ToGraphDef(&item.graph);
1003
1004 if (flib) {
1005 *item.graph.mutable_library() = flib->ToProto();
1006 }
1007
1008 tensorflow::GraphDef out_graph;
1009 tensorflow::grappler::VirtualCluster cluster(&device_set);
1010 // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
1011 // proto (which also contain the OptimizerOptions).
1012 TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
1013 std::move(item), config_proto, cpu_device, &cluster, &out_graph));
1014
1015 std::unique_ptr<tensorflow::Graph> optimized_graph(
1016 new tensorflow::Graph(OpRegistry::Global()));
1017
1018 // Copy optimized functions back to the overlay lib.
1019 if (flib) {
1020 for (const FunctionDef& fdef : out_graph.library().function()) {
1021 const string& func_name = fdef.signature().name();
1022 if (flib->Contains(func_name)) {
1023 TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, fdef));
1024 } else {
1025 TF_RETURN_IF_ERROR(flib->AddFunctionDef(fdef));
1026 }
1027 }
1028 }
1029
1030 TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
1031 GraphConstructorOptions(), std::move(out_graph), optimized_graph.get()));
1032
1033 // The graph conversion sets the requested device names but not the
1034 // assigned device names. However, since at this point the graph is
1035 // placed, TF expects an assigned device name for every node. Therefore
1036 // we copy the requested device into the assigned device field.
1037 for (Node* node : optimized_graph->nodes()) {
1038 if (node->IsOp() && node->assigned_device_name().empty()) {
1039 if (node->requested_device().empty()) {
1040 return errors::Internal(
1041 "Either placer did not place the node or Grappler did not "
1042 "copy the assigned device. Contact Grappler team since latter "
1043 "is more likely. Node=",
1044 node->name(),
1045 " Graph: ", optimized_graph->ToGraphDefDebug().DebugString());
1046 }
1047 node->set_assigned_device_name(node->requested_device());
1048 }
1049 }
1050
1051 *g = std::move(optimized_graph);
1052 return Status::OK();
1053 }
1054
1055 } // namespace grappler
1056 } // namespace tensorflow
1057