/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/meta_optimizer.h"

#include "absl/container/flat_hash_set.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "absl/strings/substitute.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/common_runtime/metrics.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_util.h"
#include "tensorflow/core/framework/versions.pb.h"
#include "tensorflow/core/grappler/clusters/virtual_cluster.h"
#include "tensorflow/core/grappler/optimizers/arithmetic_optimizer.h"
#include "tensorflow/core/grappler/optimizers/auto_mixed_precision.h"
#include "tensorflow/core/grappler/optimizers/auto_parallel.h"
#include "tensorflow/core/grappler/optimizers/common_subgraph_elimination.h"
#include "tensorflow/core/grappler/optimizers/constant_folding.h"
#include "tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.h"
#include "tensorflow/core/grappler/optimizers/debug_stripper.h"
#include "tensorflow/core/grappler/optimizers/dependency_optimizer.h"
#include "tensorflow/core/grappler/optimizers/function_optimizer.h"
#include "tensorflow/core/grappler/optimizers/generic_layout_optimizer.h"
#include "tensorflow/core/grappler/optimizers/implementation_selector.h"
#include "tensorflow/core/grappler/optimizers/loop_optimizer.h"
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/core/grappler/optimizers/model_pruner.h"
#include "tensorflow/core/grappler/optimizers/pin_to_host_optimizer.h"
#include "tensorflow/core/grappler/optimizers/remapper.h"
#include "tensorflow/core/grappler/optimizers/scoped_allocator_optimizer.h"
#include "tensorflow/core/grappler/optimizers/shape_optimizer.h"
#include "tensorflow/core/grappler/utils/canonicalizer.h"
#include "tensorflow/core/grappler/utils/colocation.h"
#include "tensorflow/core/grappler/utils/functions.h"
#include "tensorflow/core/grappler/utils/topological_sort.h"
#include "tensorflow/core/grappler/utils/tpu.h"
#include "tensorflow/core/grappler/verifiers/structure_verifier.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/dump_graph.h"
#include "tensorflow/core/util/ptr_util.h"
#include "tensorflow/core/util/xla_config_registry.h"

namespace tensorflow {
namespace grappler {

namespace {

constexpr int kDefaultNumberOfIterations = 2;
constexpr int kDefaultMinGraphNodes = 4;

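// Returns the number of edges in the graph, counting each entry in a node's
// input list (data or control input) as one edge.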
int64 NumEdges(const GraphDef& graph) {
  int64 num_edges = 0;
  for (const auto& node : graph.node()) {
    num_edges += node.input_size();
  }
  return num_edges;
}

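// Builds a human-readable summary of the optimized graph size, with the node
// and edge deltas relative to the input graph in parentheses.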
string PrintSizesBeforeAfter(const GraphDef& before, const GraphDef& after) {
  return strings::StrCat("Graph size after: ", after.node_size(), " nodes (",
                         after.node_size() - before.node_size(), "), ",
                         NumEdges(after), " edges (",
                         NumEdges(after) - NumEdges(before), ")");
}

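// Returns the number of meta-optimizer iterations to run, substituting
// kDefaultNumberOfIterations when the config asks for the default.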
int NumIterations(const RewriterConfig& cfg) {
  return cfg.meta_optimizer_iterations() == RewriterConfig::DEFAULT_NUM_ITERS
             ? kDefaultNumberOfIterations
             : cfg.meta_optimizer_iterations();
}

// Check if the optimizer is allowed to run only once.
bool IsRunOnceOptimizer(const string& name) {
  return name == "layout" || name == "memory_optimizer" ||
         name == "loop_optimizer" || name == "auto_mixed_precision" ||
         name == "auto_mixed_precision_mkl";
}

// Creates a function library stub from a real function library: copies only
// the signatures and attributes of all the functions defined in fdef_lib.
// This stub can be swapped in for the real function library in a graph before
// passing it to an optimizer, if the optimizer doesn't instantiate functions.
FunctionDefLibrary GetFunctionDefLibraryStub(
    const FunctionDefLibrary& fdef_lib) {
  FunctionDefLibrary stub;
  for (const FunctionDef& fn : fdef_lib.function()) {
    FunctionDef* fn_stub = stub.mutable_function()->Add();
    *(fn_stub->mutable_signature()) = fn.signature();
    *(fn_stub->mutable_attr()) = fn.attr();
    *(fn_stub->mutable_arg_attr()) = fn.arg_attr();
    *(fn_stub->mutable_resource_arg_unique_id()) = fn.resource_arg_unique_id();
  }
  *stub.mutable_gradient() = fdef_lib.gradient();
  return stub;
}

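// Converts the meta-optimizer timeout into an absolute deadline in
// microseconds. A negative timeout means "no deadline" (returns 0), and a
// zero timeout falls back to a 20-minute default.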
uint64 DeadlineMicroSeconds(const RewriterConfig& cfg) {
  const uint64 kTwentyMinutesInUsec = 20 * 60 * 1000 * 1000;
  if (cfg.meta_optimizer_timeout_ms() < 0) {
    return 0;
  } else {
    return cfg.meta_optimizer_timeout_ms() == 0
               ? Env::Default()->NowMicros() + kTwentyMinutesInUsec
               : Env::Default()->NowMicros() +
                     cfg.meta_optimizer_timeout_ms() * 1000;
  }
}

// A helper function to decide whether to enable the automatic mixed precision
// optimizer.
bool AutoMixedPrecisionEnabled(RewriterConfig::Toggle opt_level) {
  if (opt_level == RewriterConfig::ON ||
      opt_level == RewriterConfig::AGGRESSIVE) {
    return true;
  }
  return false;
}

bool IsXlaGlobalJitOn(
    const OptimizerOptions::GlobalJitLevel& jit_level_in_session_opts) {
  xla_config_registry::XlaGlobalJitLevel xla_global_jit_level =
      xla_config_registry::GetGlobalJitLevel(jit_level_in_session_opts);
  // Return true only if XLA JIT is ON for both single-gpu and multi-gpu
  // graphs. This is a conservative approach that turns off the memory
  // optimizer only when we are sure that all graphs will be processed by
  // XLA JIT.
  bool is_on = (xla_global_jit_level.single_gpu == OptimizerOptions::ON_1 ||
                xla_global_jit_level.single_gpu == OptimizerOptions::ON_2) &&
               (xla_global_jit_level.general == OptimizerOptions::ON_1 ||
                xla_global_jit_level.general == OptimizerOptions::ON_2);
  return is_on;
}

// A helper function to decide whether to enable the memory optimizer.
bool MemoryOptimizerEnabled(
    RewriterConfig::MemOptType mem_opt_type,
    OptimizerOptions::GlobalJitLevel jit_level_in_session_opts) {
  // Disable the default memory optimizer when XLA JIT is ON, as it hurts the
  // XLA JIT performance. The (current) XLA clustering can result in loss of
  // concurrency between kernel compute and memory copies. As such, it usually
  // loses the concurrency needed to hide the latencies of the inserted swap-ins
  // and swap-outs and incurs a large performance overhead. Remove this check
  // when the XLA JIT can better deal with the concurrency.
  if (mem_opt_type == RewriterConfig::DEFAULT_MEM_OPT &&
      IsXlaGlobalJitOn(jit_level_in_session_opts)) {
    return false;
  }

  return mem_opt_type != RewriterConfig::NO_MEM_OPT;
}

}  // namespace

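// Helper for MakeNewOptimizer below: returns a freshly allocated optimizer
// wrapped in a unique_ptr when `optimizer` matches NAME.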
#define MK_OPT(NAME, VALUE) \
  if (optimizer == NAME) return std::unique_ptr<GraphOptimizer>(VALUE)

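// Decides whether the function optimizer should lower functional control flow
// ops; lowering is disabled for the single-threaded executor and for TFRT.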
bool MetaOptimizer::LowerControlFlow() const {
  if (config_proto_.experimental().executor_type() ==
      "SINGLE_THREADED_EXECUTOR")
    return false;

  if (config_proto_.experimental().use_tfrt()) return false;

  return true;
}

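// Creates an optimizer instance from its registered short name (e.g.
// "pruning", "constfold", "layout"); returns nullptr for unknown names.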
std::unique_ptr<GraphOptimizer> MetaOptimizer::MakeNewOptimizer(
    const string& optimizer) const {
  MK_OPT("pruning", new ModelPruner());
  MK_OPT("function",
         new FunctionOptimizer(cfg_.function_optimization(),
                               /*lower_control_flow=*/LowerControlFlow()));
  MK_OPT("constfold",
         new ConstantFolding(
             cpu_device_,
             cfg_.experimental_disable_compressed_tensor_optimization(),
             !cfg_.experimental_disable_folding_quantization_emulation()));
  MK_OPT("shape", new ShapeOptimizer());
  MK_OPT("remap", new Remapper(cfg_.remapping()));
  MK_OPT("layout", new GenericLayoutOptimizer(
                       /*optimization level*/ cfg_.layout_optimizer(),
                       /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
  MK_OPT("auto_mixed_precision",
         new AutoMixedPrecision(AutoMixedPrecisionMode::CUDA));
  MK_OPT("auto_mixed_precision_mkl",
         new AutoMixedPrecision(AutoMixedPrecisionMode::MKL));
  MK_OPT("memory", new MemoryOptimizer(RewriterConfig::MANUAL));
  MK_OPT("common_subgraph_elimination",
         new CommonSubgraphElimination(cfg_.common_subgraph_elimination()));
  MK_OPT("arithmetic", new ArithmeticOptimizer(cfg_.arithmetic_optimization()));
  MK_OPT("autoparallel", new AutoParallel(cfg_.auto_parallel().num_replicas()));
  MK_OPT("loop", new LoopOptimizer(cfg_.loop_optimization(), cpu_device_));
  MK_OPT("dependency", new DependencyOptimizer(cfg_.dependency_optimization()));
  MK_OPT("debug_stripper", new DebugStripper());
  MK_OPT("scoped_allocator",
         new ScopedAllocatorOptimizer(cfg_.scoped_allocator_optimization(),
                                      cfg_.scoped_allocator_opts()));
  MK_OPT("pin_to_host",
         new PinToHostOptimizer(cfg_.pin_to_host_optimization()));

  return std::unique_ptr<GraphOptimizer>();
}

#undef MK_OPT

MetaOptimizer::MetaOptimizer(DeviceBase* cpu_device, const ConfigProto& cfg)
    : cpu_device_(cpu_device),
      config_proto_(cfg),
      cfg_(*config_proto_.mutable_graph_options()->mutable_rewrite_options()) {
  DCHECK(cpu_device_ == nullptr ||
         cpu_device_->attributes().device_type() == "CPU");
}

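// Populates `optimizers` with the default optimizer pipeline, in the order
// the passes will run, honoring the per-pass toggles in the RewriterConfig,
// and then appends any configured custom optimizers.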
Status MetaOptimizer::InitializeOptimizers(
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  if (cfg_.disable_meta_optimizer()) {
    return Status::OK();
  }
  if (!cfg_.disable_model_pruning()) {
    optimizers->push_back(MakeUnique<ModelPruner>());
  }
  if (cfg_.implementation_selector() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<ImplementationSelector>());
  }
  if (cfg_.function_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<FunctionOptimizer>(
        cfg_.function_optimization(),
        /*lower_control_flow=*/LowerControlFlow()));
  }
  if (cfg_.common_subgraph_elimination() != RewriterConfig::OFF &&
      cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<CommonSubgraphElimination>(
        cfg_.common_subgraph_elimination()));
  }
  if (cfg_.debug_stripper() == RewriterConfig::ON) {
    optimizers->push_back(MakeUnique<DebugStripper>());
  }
  if (cfg_.constant_folding() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<ConstantFolding>(
        cfg_.constant_folding(), cpu_device_,
        cfg_.experimental_disable_compressed_tensor_optimization(),
        !cfg_.experimental_disable_folding_quantization_emulation()));
  }
  if (cfg_.shape_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<ShapeOptimizer>());
  }
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision())) {
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::CUDA));
  }
  if (AutoMixedPrecisionEnabled(cfg_.auto_mixed_precision_mkl())) {
    optimizers->push_back(
        MakeUnique<AutoMixedPrecision>(AutoMixedPrecisionMode::MKL));
  }
  if (cfg_.pin_to_host_optimization() == RewriterConfig::ON) {
    optimizers->push_back(MakeUnique<PinToHostOptimizer>());
  }
  if (cfg_.arithmetic_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<ArithmeticOptimizer>(cfg_.arithmetic_optimization()));
  }
  if (cfg_.layout_optimizer() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<GenericLayoutOptimizer>(
        /*optimization level*/ cfg_.layout_optimizer(),
        /*CPU layout conversion*/ cfg_.cpu_layout_conversion()));
  }
  if (cfg_.remapping() != RewriterConfig::OFF) {
    optimizers->push_back(MakeUnique<Remapper>(cfg_.remapping()));
  }
  if (cfg_.loop_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<LoopOptimizer>(cfg_.loop_optimization(), cpu_device_));
  }
  if (cfg_.dependency_optimization() != RewriterConfig::OFF) {
    optimizers->push_back(
        MakeUnique<DependencyOptimizer>(cfg_.dependency_optimization()));
  }
  auto global_jit_level =
      config_proto_.graph_options().optimizer_options().global_jit_level();
  if (MemoryOptimizerEnabled(cfg_.memory_optimization(), global_jit_level)) {
    if (cfg_.memory_optimizer_target_node_name_scope().empty()) {
      optimizers->push_back(
          // Use the default target node name prefix "gradients/".
          MakeUnique<MemoryOptimizer>(cfg_.memory_optimization()));
    } else {
      optimizers->push_back(MakeUnique<MemoryOptimizer>(
          cfg_.memory_optimization(),
          cfg_.memory_optimizer_target_node_name_scope()));
    }
  }
  if (cfg_.auto_parallel().enable()) {
    optimizers->push_back(
        MakeUnique<AutoParallel>(cfg_.auto_parallel().num_replicas()));
  }
  if (cfg_.scoped_allocator_optimization()) {
    optimizers->push_back(MakeUnique<ScopedAllocatorOptimizer>(
        cfg_.scoped_allocator_optimization(), cfg_.scoped_allocator_opts()));
  }
  return InitializeCustomGraphOptimizers(std::set<string>(), optimizers);
}

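// Populates `optimizers` from the explicit optimizer name list in the config.
// Each name is resolved first against the default optimizers and then against
// the custom optimizer registry.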
Status MetaOptimizer::InitializeOptimizersByName(
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  std::set<string> initialized_custom_optimizers;
  for (const string& optimizer_name : cfg_.optimizers()) {
    auto optimizer = MakeNewOptimizer(optimizer_name);
    if (optimizer) {
      VLOG(2) << "Registered default graph optimizer: " << optimizer_name;
      optimizers->push_back(std::move(optimizer));
      continue;
    }

    auto custom_optimizer =
        CustomGraphOptimizerRegistry::CreateByNameOrNull(optimizer_name);

    if (custom_optimizer) {
      VLOG(2) << "Registered custom graph optimizer: " << optimizer_name;
      TF_RETURN_IF_ERROR(custom_optimizer->InitWithConfig(
          config_proto_, GetCustomGraphOptimizerConfig(optimizer_name)));
      optimizers->push_back(std::move(custom_optimizer));
      initialized_custom_optimizers.insert(optimizer_name);
    } else {
      VLOG(2) << "Can't register an optimizer by name: " << optimizer_name;
    }
  }
  return InitializeCustomGraphOptimizers(initialized_custom_optimizers,
                                         optimizers);
}

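// Instantiates the custom optimizers listed in the config, skipping those
// already initialized by name, and falls back to a default optimizer when no
// custom optimizer is registered under a given name.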
Status MetaOptimizer::InitializeCustomGraphOptimizers(
    const std::set<string>& pre_initialized_optimizers,
    std::vector<std::unique_ptr<GraphOptimizer>>* optimizers) const {
  for (const auto& optimizer_config : cfg_.custom_optimizers()) {
    if (pre_initialized_optimizers.find(optimizer_config.name()) !=
        pre_initialized_optimizers.end()) {
      continue;
    }

    auto custom_optimizer = CustomGraphOptimizerRegistry::CreateByNameOrNull(
        optimizer_config.name());

    if (custom_optimizer) {
      VLOG(2) << "Registered custom configurable graph optimizer: "
              << optimizer_config.name();
      TF_RETURN_IF_ERROR(
          custom_optimizer->InitWithConfig(config_proto_, &optimizer_config));
      optimizers->push_back(std::move(custom_optimizer));
    } else {
      // If there is no custom optimizer with the given name, try to
      // initialize a default optimizer. This way, custom configurable
      // optimizers can be mixed with default optimizers in any order.
      auto optimizer = MakeNewOptimizer(optimizer_config.name());
      if (optimizer) {
        VLOG(2) << "Registered default graph optimizer: "
                << optimizer_config.name();
        optimizers->push_back(std::move(optimizer));
        continue;
      }
      VLOG(2) << "Can't register an optimizer by name: "
              << optimizer_config.name();
    }
  }
  return Status::OK();
}

const RewriterConfig::CustomGraphOptimizer*
MetaOptimizer::GetCustomGraphOptimizerConfig(const string& name) const {
  for (const auto& config : cfg_.custom_optimizers()) {
    if (config.name() == name) {
      return &config;
    }
  }
  return nullptr;
}

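// Instantiates the verifiers requested by the inter-optimizer and
// post-optimization verifier configs.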
void MetaOptimizer::InitializeVerifiers(
    std::vector<std::unique_ptr<GraphVerifier>>* inter_optimizer_verifiers,
    std::vector<std::unique_ptr<GraphVerifier>>* post_optimization_verifiers)
    const {
  if (cfg_.inter_optimizer_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    inter_optimizer_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
  if (cfg_.post_optimization_verifier_config().structure_verifier() ==
      VerifierConfig::ON) {
    post_optimization_verifiers->push_back(MakeUnique<StructureVerifier>());
  }
}

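// Runs the configured optimizer pipeline over the item's main graph for up to
// NumIterations(cfg_) iterations, running verifiers between passes and
// leaving the result in `optimized_graph`.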
Status MetaOptimizer::OptimizeGraph(Cluster* cluster, GrapplerItem&& item,
                                    GraphDef* optimized_graph) {
  int min_graph_nodes = cfg_.min_graph_nodes() == 0 ? kDefaultMinGraphNodes
                                                    : cfg_.min_graph_nodes();
  if (item.graph.node_size() < min_graph_nodes) {
    VLOG(3) << "Skipping optimization, graph has fewer than " << min_graph_nodes
            << " nodes.";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  const uint64 start_us = Env::Default()->NowMicros();

  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
  if (cfg_.optimizers().empty()) {
    TF_RETURN_IF_ERROR(InitializeOptimizers(&optimizers));
  } else {
    TF_RETURN_IF_ERROR(InitializeOptimizersByName(&optimizers));
  }

  // Initialize the configured verifiers.
  std::vector<std::unique_ptr<GraphVerifier>> inter_optimizer_verifiers;
  std::vector<std::unique_ptr<GraphVerifier>> post_optimization_verifiers;
  InitializeVerifiers(&inter_optimizer_verifiers, &post_optimization_verifiers);
  if (inter_optimizer_verifiers.empty()) {
    VLOG(2) << "No inter optimizer verifiers have been configured";
  } else {
    VLOG(2) << inter_optimizer_verifiers.size()
            << " inter optimizer verifiers have been configured";
  }
  if (post_optimization_verifiers.empty()) {
    VLOG(2) << "No post optimization verifiers have been configured";
  } else {
    VLOG(2) << post_optimization_verifiers.size()
            << " post optimization verifiers have been configured";
  }

  VLOG(2) << "Optimize GrapplerItem: item.id=" << item.id
          << " num_optimizers=" << optimizers.size()
          << ", num nodes = " << item.graph.node_size();

  if (optimizers.empty()) {
    VLOG(3) << "Skipping graph optimization, no optimizers registered";
    *optimized_graph = item.graph;
    return Status::OK();
  }

  // Invariant: optimized_graph contains the most recently optimized version of
  // the graph.
  auto original_producer = item.graph.versions().producer();
  optimized_graph->Swap(&item.graph);

  GraphOptimizationResult optimization_result(item.id);
  GraphOptimizer* sa_optimizer = nullptr;

  // Constants in the graph are normally compressed after model_pruner.
  // Do it here if the model pruner is disabled.
  if (cfg_.disable_model_pruning()) {
    CompressConstants(optimized_graph);
  }

  for (int iteration = 0; iteration < NumIterations(cfg_); ++iteration) {
    // Don't bother optimizing further if the graph is already tiny.
    if (optimized_graph->node_size() < min_graph_nodes) {
      VLOG(3) << "Stopping after iteration " << iteration
              << ", graph is tiny (#nodes = " << optimized_graph->node_size()
              << "  < " << min_graph_nodes << ")";
      break;
    }

    VLOG(4) << "Starting optimization iteration " << iteration;
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("before_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }

    for (const auto& optimizer : optimizers) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
      // Some optimizers can run only once.
      if (iteration > 0 && IsRunOnceOptimizer(optimizer->name())) continue;
      // Some must run only on the last iteration.
      if (optimizer->name() == "scoped_allocator_optimizer") {
        if (sa_optimizer == nullptr) sa_optimizer = optimizer.get();
        continue;
      }

      TF_RETURN_IF_ERROR(RunOptimizer(optimizer.get(), cluster, &item,
                                      optimized_graph, &optimization_result));

      if (iteration == 0 && optimizer->name() == "model_pruner") {
        CompressConstants(optimized_graph);
      }

      if (VLOG_IS_ON(4)) {
        DumpGraphDefToFile(
            strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                            optimizer->name(), "_",
                            reinterpret_cast<uintptr_t>(optimized_graph)),
            *optimized_graph);
      }
      for (const auto& verifier : inter_optimizer_verifiers) {
        // TODO(ashwinm): Need to enforce verification_deadline.
        TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
      }
    }
    if (VLOG_IS_ON(4)) {
      DumpGraphDefToFile(
          strings::StrCat("after_MetaOptimizer_iteration_", iteration, "_",
                          reinterpret_cast<uintptr_t>(optimized_graph)),
          *optimized_graph);
    }
    // TODO(ashwinm): Need to enforce verification_deadline.
    for (const auto& verifier : post_optimization_verifiers) {
      TF_RETURN_IF_ERROR(verifier->Verify(*optimized_graph));
    }
  }

  // ScopedAllocatorOptimizer must run last.
  if (sa_optimizer != nullptr) {
    TF_RETURN_IF_ERROR(RunOptimizer(sa_optimizer, cluster, &item,
                                    optimized_graph, &optimization_result));
    GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();
  }

  bool is_optimized = std::find_if(optimization_result.results.begin(),
                                   optimization_result.results.end(),
                                   [](const OptimizerResult& result) {
                                     return result.status.ok();
                                   }) != optimization_result.results.end();

  // Record the graph optimization result.
  optimization_results_.push_back(optimization_result);

  if (is_optimized) {
    TF_RETURN_IF_ERROR(TopologicalSort(optimized_graph));
    ReassignColocation(optimized_graph);
    // Make sure that the optimizers preserved the graph version.
    DCHECK_EQ(optimized_graph->versions().producer(), original_producer);
  }

  const uint64 end_us = Env::Default()->NowMicros();
  metrics::UpdateGrapplerPassTime("OptimizeMainGraph", end_us - start_us);

  return Status::OK();
}

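// Runs a single optimizer pass over `optimized_item` and appends its outcome
// to `optimization_result`. An Aborted status is treated as "the optimizer
// made no changes" rather than as a failure.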
Status MetaOptimizer::RunOptimizer(
    GraphOptimizer* optimizer, Cluster* cluster, GrapplerItem* optimized_item,
    GraphDef* optimized_graph, GraphOptimizationResult* optimization_result) {
  const uint64 start_us = Env::Default()->NowMicros();

  // If the optimizer doesn't need a function library, we replace it with a
  // stub before running the optimization and put the real one back at the end.
  FunctionDefLibrary optimized_graph_function_library;
  const bool is_function_library_aware = optimizer->UsesFunctionLibrary();

  // Replace the function library in the optimized graph with a stub.
  if (!is_function_library_aware) {
    VLOG(3) << "Replace function library with a stub for " << optimizer->name();
    optimized_graph_function_library.Swap(optimized_graph->mutable_library());
    *optimized_graph->mutable_library() =
        GetFunctionDefLibraryStub(optimized_graph_function_library);
  }

  // This swaps the current optimized_graph into the optimized item and
  // resets optimized_graph to an empty graph.
  optimized_graph->Swap(&optimized_item->graph);
  *optimized_graph = GraphDef();
  optimizer->set_deadline_usec(this->deadline_usec());
  Status status =
      optimizer->Optimize(cluster, *optimized_item, optimized_graph);
  const uint64 end_us = Env::Default()->NowMicros();
  const float duration_ms = (end_us - start_us) / 1000.0f;
  metrics::UpdateGrapplerPassTime(optimizer->name(), end_us - start_us);

  string message;
  if (!status.ok()) {
    optimized_graph->Swap(&optimized_item->graph);
    if (errors::IsAborted(status)) {
      // By convention we (ab-)use the Aborted error code to signal that the
      // optimizer returned without performing any changes to the graph.
      message = strings::StrCat(optimizer->name(),
                                " did nothing. time = ", duration_ms, "ms.");
      // Swallow the non-critical error.
      status = Status::OK();
    } else if (errors::IsDeadlineExceeded(status)) {
      message =
          strings::StrCat(status.ToString(), ", time = ", duration_ms, "ms.");
      LOG(WARNING) << optimizer->name() << " failed: " << message;
    } else {
      message = status.ToString();
      LOG(ERROR) << optimizer->name() << " failed: " << message;
    }
  } else {
    message = strings::StrCat(
        PrintSizesBeforeAfter(optimized_item->graph, *optimized_graph),
        ", time = ", duration_ms, "ms.");
    VLOG(1) << optimizer->name() << ": " << message;
  }

  // Swap the function library back into the main graph.
  if (!is_function_library_aware) {
    optimized_graph->mutable_library()->Swap(&optimized_graph_function_library);
  }

  OptimizerResult optimizer_result{optimizer->name(), message, status};
  optimization_result->results.push_back(optimizer_result);

  if (!status.ok() && cfg_.fail_on_optimizer_errors()) return status;

  return Status::OK();
}

// Propagates `_tf_data_function` attributes from functions to their callees.
void PropagateTFDataAttrs(const FunctionLibraryDefinition& flib,
                          FunctionDefLibrary& fdef_lib) {
  // Collect functions that need the attribute in this set.
  absl::flat_hash_set<std::string> tf_data_functions;
  std::function<void(const std::string&)> collect_tf_data_functions_dfs =
      [&](const std::string& func_name) -> void {
    const FunctionDef* func_def = flib.Find(func_name);
    // Skip functions that are not reachable from the optimized graph.
    if (func_def == nullptr) return;

    // Return if we already found and added this function.
    if (tf_data_functions.contains(func_name)) return;

    // We only get here if the function is (directly or indirectly) called from
    // a tf.data function, so add it to the set.
    tf_data_functions.insert(func_name);

    // Proceed with the DFS for functions called from the current function.
    for (const NodeDef& node : func_def->node_def()) {
      if (flib.Contains(node.op())) {
        // This is a function call node.
        collect_tf_data_functions_dfs(node.op());
      }
      // Check if there are functions in attributes.
      for (const auto& attr : node.attr()) {
        const AttrValue& attr_value = attr.second;
        if (attr_value.has_func()) {
          collect_tf_data_functions_dfs(attr_value.func().name());
        }
        if (attr_value.has_list()) {
          for (const auto& func : attr_value.list().func()) {
            collect_tf_data_functions_dfs(func.name());
          }
        }
      }
    }
  };
  // Perform the DFS for all tf.data functions in `fdef_lib`.
  for (const auto& func_def : fdef_lib.function()) {
    const std::string& func_name = func_def.signature().name();
    if (data::IsTFDataFunction(func_def))
      collect_tf_data_functions_dfs(func_name);
  }
  // Set the attribute for tf.data functions. We cannot do this in the DFS
  // directly because `FunctionLibraryDefinition` does not seem to provide
  // mutable access to a `FunctionDef`.
  for (FunctionDef& func_def : *fdef_lib.mutable_function()) {
    const std::string& func_name = func_def.signature().name();
    if (tf_data_functions.contains(func_name) &&
        !data::IsTFDataFunction(func_def)) {
      VLOG(2) << "Marking " << func_name << " as tf.data function";
      (*func_def.mutable_attr())[data::kTFDataFunction].set_b(true);
    }
  }
}

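// Optimizes the main graph and then every function reachable from it,
// consuming `item` in the process.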
Status MetaOptimizer::OptimizeConsumeItem(Cluster* cluster, GrapplerItem&& item,
                                          GraphDef* optimized_graph) {
  const uint64 start_us = Env::Default()->NowMicros();

  VLOG(1) << "Starting optimization for grappler item: " << item.id;
  optimization_results_.clear();

  // Constructs a FunctionLibraryDefinition with functions that are reachable
  // from the nodes of the graph.
  const auto minimized_flib =
      [](const GraphDef& graph) -> FunctionLibraryDefinition {
    return FunctionLibraryDefinition(OpRegistry::Global(), graph.library())
        .ReachableDefinitions(graph);
  };

  // 0. The original graph might contain a huge function library that is mostly
  // unused. This library is copied over by each individual Grappler optimizer,
  // which adds a huge overhead. Before starting the optimization passes we
  // just remove all the unreachable functions.
  // TODO(ezhulenev): Construct the reachable function library definition
  // directly from the proto without constructing a temporary
  // FunctionLibraryDefinition.
  int old_library_size = item.graph.library().function_size();
  *item.graph.mutable_library() = minimized_flib(item.graph).ToProto();
  int new_library_size = item.graph.library().function_size();

  VLOG(1) << absl::Substitute(
      "Deleted $0 unreachable functions from the graph (library size = $1)",
      old_library_size - new_library_size, new_library_size);

  // Save a few small fields from the item before we move it.
  bool optimize_function_library =
      item.optimization_options().optimize_function_library;
  const auto producer = item.graph.versions().producer();

  // 1. Optimize the main graph.
  TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(item), optimized_graph));
  VLOG(1) << "Optimized main graph.";
  GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

  // 2. Optimize functions reachable from the optimized graph.
  FunctionLibraryDefinition flib = minimized_flib(*optimized_graph);
  using NodeDefs = protobuf::RepeatedPtrField<NodeDef>;

  // Find functions for which we might need to compute a gradient at runtime.
  absl::flat_hash_set<string> differentiable_functions;

  const auto find_differentiable_functions =
      [&](const NodeDefs& nodes) -> void {
    for (const NodeDef& node : nodes) {
      if (IsSymbolicGradient(node)) {
        const auto* f_attr = gtl::FindOrNull(node.attr(), "f");
        if (f_attr) differentiable_functions.insert(f_attr->func().name());
      }
    }
  };

  // SymbolicGradient nodes inside the main graph.
  find_differentiable_functions(optimized_graph->node());
  // SymbolicGradient nodes inside the function library.
  for (const FunctionDef& function : optimized_graph->library().function()) {
    find_differentiable_functions(function.node_def());
  }

  // Find functions that will be compiled by XLA later. We do this by looking
  // for XlaLaunch ops that call functions, then depth-first searching down
  // those functions to find transitively called functions. Grappler rewrites
  // can potentially add nodes that are not supported by XLA, so we choose to
  // skip such functions when we optimize the function library.
  absl::flat_hash_set<string> xla_compiled_functions;
  std::function<void(const string&)> find_all_functions;
  find_all_functions = [&](const string& func) -> void {
    // Ignore call cycles in the graph.
    if (xla_compiled_functions.contains(func)) return;
    // Find func in the flib.
    const FunctionDef* func_def = flib.Find(func);
    CHECK(func_def) << "not found: " << func;
    // Mark the function to be ignored by Grappler.
    xla_compiled_functions.insert(func);
    // Depth-first search through the function for transitively called funcs.
    for (const NodeDef& node : func_def->node_def()) {
      for (const auto& attr : node.attr()) {
        const AttrValue& attr_value = attr.second;
        if (attr_value.has_func()) {
          find_all_functions(attr_value.func().name());
        }
      }
    }
  };

  auto find_xla_compiled_functions = [&](const NodeDefs& nodes) -> void {
    NameAttrList function;
    for (const NodeDef& node : nodes) {
      // Look only for XlaLaunch nodes that call a function.
      if (!IsXlaLaunch(node)) continue;
      if (!GetNodeAttr(node, "function", &function).ok()) continue;
      // Find all transitively called functions.
      find_all_functions(function.name());
    }
  };

  // XlaLaunch ops inside the main graph ...
  find_xla_compiled_functions(optimized_graph->node());
  // ... and inside the function library.
  for (const FunctionDef& function : optimized_graph->library().function()) {
    find_xla_compiled_functions(function.node_def());
  }
  // Propagate `_tf_data_function` attributes from functions to their callees.
  PropagateTFDataAttrs(flib, *optimized_graph->mutable_library());

  // Optimize each function only once.
  absl::flat_hash_set<string> optimized_funcs;
  while (optimize_function_library) {
    optimize_function_library = false;

    int function_idx = 0;
    for (const FunctionDef& func : optimized_graph->library().function()) {
      GRAPPLER_RETURN_IF_DEADLINE_EXCEEDED();

      const string& func_name = func.signature().name();

      // Skip functions that are not reachable from the optimized graph.
      if (!flib.Contains(func_name)) continue;
      // Skip already optimized functions.
      if (optimized_funcs.contains(func_name)) continue;
      // Skip functions that will be compiled by XLA.
      if (xla_compiled_functions.contains(func_name)) continue;

      // Skip parametrized functions (the function type or body is defined only
      // at function call time by caller node attributes).
      // They should be specialized to their instantiation type parameters by
      // the function optimizer before we can optimize the function body.
      if (IsParametrized(func)) continue;

      // Skip tf.data functions as they are optimized by the tf.data meta
      // optimizer and in function instantiation.
      if (data::IsTFDataFunction(func)) continue;

      VLOG(3) << "Optimize function: function=" << func_name << " ["
              << function_idx++ << " of "
              << optimized_graph->library().function_size() << "]";

      // Function optimization might specialize nested function calls, so we
      // have to reset the flag and do at least one more pass over the library.
      optimize_function_library = true;
      optimized_funcs.insert(func_name);

      // Make a GrapplerItem from a FunctionDef.
      GrapplerFunctionItem func_item;
      TF_RETURN_IF_ERROR(
          MakeGrapplerFunctionItem(func, flib, producer, &func_item));

      // If we need to compute the gradient of the optimized function at
      // runtime, we can't perform non-differentiable rewrites.
      func_item.optimization_options().allow_non_differentiable_rewrites =
          !differentiable_functions.contains(func_name);

      // The device set available to the function is defined only by the
      // runtime, when we instantiate and execute the function. We can't use
      // all devices available to the main graph, because after partitioning
      // the function call node might execute on a remote worker.
      if (!func_item.devices().empty()) {
        return errors::Internal("GrapplerFunctionItem devices must be empty.");
      }

      // We are not allowed to prune certain types of ops from the graph
      // instantiated by the function definition, because we must guarantee
      // function execution semantics with respect to side effects (see
      // function_optimizer.cc).
      func_item.optimization_options().allow_pruning_stateful_and_dataset_ops =
          false;

      // Optimize the function body graph.
      GraphDef optimized_func_graph;
      if (IsTPUGraphDef(*optimized_graph)) {
        // Skip optimizing functions if this is a TPU graph. Currently,
        // Grappler passes do not handle TPU functions correctly in a variety
        // of ways (note that due to the pre-placement TPU graph rewriting
        // passes, the TPU-related ops are encapsulated away into functions).
        // For example, TPU graphs contain a TPUReplicateMetadata node that
        // carries relevant TPU metadata, and Grappler passes could prune that
        // away. Grappler passes could also cause issues around shape
        // inference. Since the desired and existing behavior is to not
        // optimize TPU functions with Grappler, this check preserves that.
        // The only exception is the implementation selector, which is required
        // to swap in some TPU-specific lowering code and is verified to work
        // correctly on TPUs.
        ImplementationSelector implementation_selector;

        // The implementation selector needs access to a valid function
        // signature and attributes, but it doesn't need the actual function
        // body.
        FunctionDefLibrary func_item_function_library;
        func_item_function_library.Swap(func_item.graph.mutable_library());
        *func_item.graph.mutable_library() =
            GetFunctionDefLibraryStub(func_item_function_library);

        TF_RETURN_IF_ERROR(implementation_selector.Optimize(
            cluster, func_item, &optimized_func_graph));
      } else {
        GrapplerFunctionItem func_item_copy = func_item;
        TF_RETURN_IF_ERROR(OptimizeGraph(cluster, std::move(func_item_copy),
                                         &optimized_func_graph));
      }

      // Function body optimization might have created new specialized
      // functions for each instantiation context. Add them to the library.
      for (const FunctionDef& func_def :
           optimized_func_graph.library().function()) {
        if (flib.Find(func_def.signature().name()) == nullptr) {
          TF_RETURN_IF_ERROR(flib.AddFunctionDef(func_def));
        }
      }

      // Convert the optimized graph back to a FunctionDef.
      FunctionDef optimized_func;
      func_item.SwapFunctionBody(std::move(optimized_func_graph));
      TF_RETURN_IF_ERROR(MakeFunctionDef(func_item, flib, &optimized_func));

      // Replace the optimized function with a new FunctionDef.
      TF_RETURN_IF_ERROR(flib.ReplaceFunction(func_name, optimized_func));
    }

    // If we optimized at least one function, update the graph library.
    if (optimize_function_library) {
      *optimized_graph->mutable_library() = flib.ToProto();
    }
  }

  VLOG(1) << "Optimized " << optimized_funcs.size()
          << " functions: " << absl::StrJoin(optimized_funcs, ", ");
  VLOG(3) << "Optimized graph =\n" << optimized_graph->DebugString();
  if (VLOG_IS_ON(1)) {
    DumpGraphDefToFile(
        strings::StrCat("after_MetaOptimizer_",
                        reinterpret_cast<uintptr_t>(optimized_graph)),
        *optimized_graph);
  }

  const uint64 end_us = Env::Default()->NowMicros();
  metrics::UpdateGrapplerPassTime("*", end_us - start_us);

  return Status::OK();
}

string MetaOptimizer::GetResultString() const {
  std::string result_string;
  for (const GraphOptimizationResult& graph_result : optimization_results_) {
    absl::StrAppend(&result_string,
                    "Optimization results for grappler item: ", graph_result.id,
                    "\n");
    for (const OptimizerResult& result : graph_result.results) {
      absl::StrAppend(&result_string, "  ", result.optimizer_name, ": ",
                      result.message, "\n");
    }
  }
  return result_string;
}

void MetaOptimizer::PrintResult() { LOG(INFO) << GetResultString(); }

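// Returns true if the config enables at least one rewrite pass, i.e. running
// the meta-optimizer would not be a no-op.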
bool MetaOptimizerEnabled(const ConfigProto& cfg) {
  const auto& rewrite_cfg = cfg.graph_options().rewrite_options();
  if (rewrite_cfg.disable_meta_optimizer()) {
    return false;
  }
  return !rewrite_cfg.disable_model_pruning() ||
         rewrite_cfg.layout_optimizer() != RewriterConfig::OFF ||
         rewrite_cfg.function_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.constant_folding() != RewriterConfig::OFF ||
         rewrite_cfg.shape_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.remapping() != RewriterConfig::OFF ||
         rewrite_cfg.common_subgraph_elimination() != RewriterConfig::OFF ||
         rewrite_cfg.arithmetic_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.loop_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.dependency_optimization() != RewriterConfig::OFF ||
         rewrite_cfg.auto_parallel().enable() ||
         rewrite_cfg.memory_optimization() != RewriterConfig::NO_MEM_OPT ||
         rewrite_cfg.debug_stripper() == RewriterConfig::ON ||
         rewrite_cfg.scoped_allocator_optimization() == RewriterConfig::ON ||
         rewrite_cfg.pin_to_host_optimization() == RewriterConfig::ON ||
         AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision()) ||
         AutoMixedPrecisionEnabled(rewrite_cfg.auto_mixed_precision_mkl()) ||
         !rewrite_cfg.optimizers().empty() ||
         !rewrite_cfg.custom_optimizers().empty();
}

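// Convenience wrapper: builds a MetaOptimizer from `cfg`, installs the
// deadline derived from the rewriter config, and optimizes `item`.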
Status RunMetaOptimizer(GrapplerItem&& item, const ConfigProto& cfg,
                        DeviceBase* cpu_device, Cluster* cluster,
                        GraphDef* optimized_graph) {
  MetaOptimizer optimizer(cpu_device, cfg);
  optimizer.set_deadline_usec(
      DeadlineMicroSeconds(cfg.graph_options().rewrite_options()));
  return optimizer.OptimizeConsumeItem(cluster, std::move(item),
                                       optimized_graph);
}

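// Optimizes a tensorflow::Graph in place using Grappler: converts it to a
// GrapplerItem, runs the meta-optimizer on a virtual cluster built from
// `device_set`, copies optimized functions back into `flib`, and converts
// the optimized GraphDef back into `*g`.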
Status OptimizeGraph(
    std::vector<string> ret_node_names, std::vector<string> keep_node_names,
    FunctionLibraryDefinition* flib, const DeviceSet& device_set,
    Device* cpu_device, const ConfigProto& config_proto,
    const string& grappler_item_id,
    const GrapplerItem::OptimizationOptions& optimization_options,
    std::unique_ptr<tensorflow::Graph>* g) {
  if (!tensorflow::grappler::MetaOptimizerEnabled(config_proto)) {
    return Status::OK();
  }

  tensorflow::grappler::GrapplerItem item;
  item.id = grappler_item_id;
  item.optimization_options() = optimization_options;

  // Add all available devices so that inlined functions can be placed.
  for (const Device* d : device_set.devices()) {
    Status added_device = item.AddDevice(d->name());
    if (!added_device.ok()) VLOG(3) << added_device.error_message();
  }
  VLOG(3) << "Grappler available devices: "
          << absl::StrJoin(item.devices(), ", ");

  // Add fetches so that the graph can be pruned.
  item.fetch.swap(ret_node_names);

  // Add nodes that can't be removed from the graph.
  item.keep_ops = std::move(keep_node_names);

  (*g)->ToGraphDef(&item.graph);

  if (flib) {
    *item.graph.mutable_library() = flib->ToProto();
  }

  tensorflow::GraphDef out_graph;
  tensorflow::grappler::VirtualCluster cluster(&device_set);
  // TODO(nareshmodi): Consider adding and using the more generic GraphOptions
  // proto (which also contains the OptimizerOptions).
  TF_RETURN_IF_ERROR(tensorflow::grappler::RunMetaOptimizer(
      std::move(item), config_proto, cpu_device, &cluster, &out_graph));

  std::unique_ptr<tensorflow::Graph> optimized_graph(
      new tensorflow::Graph(OpRegistry::Global()));

  // Copy optimized functions back to the overlay lib.
  if (flib) {
    for (const FunctionDef& fdef : out_graph.library().function()) {
      const string& func_name = fdef.signature().name();
      if (flib->Contains(func_name)) {
        TF_RETURN_IF_ERROR(flib->ReplaceFunction(func_name, fdef));
      } else {
        TF_RETURN_IF_ERROR(flib->AddFunctionDef(fdef));
      }
    }
  }

  TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
      GraphConstructorOptions(), std::move(out_graph), optimized_graph.get()));

  // The graph conversion sets the requested device names but not the
  // assigned device names. However, since at this point the graph is
  // placed, TF expects an assigned device name for every node. Therefore
  // we copy the requested device into the assigned device field.
  for (Node* node : optimized_graph->nodes()) {
    if (node->IsOp() && node->assigned_device_name().empty()) {
      if (node->requested_device().empty()) {
        return errors::Internal(
            "Either placer did not place the node or Grappler did not "
            "copy the assigned device. Contact Grappler team since latter "
            "is more likely. Node=",
            node->name(),
            " Graph: ", optimized_graph->ToGraphDefDebug().DebugString());
      }
      node->set_assigned_device_name(node->requested_device());
    }
  }

  *g = std::move(optimized_graph);
  return Status::OK();
}

}  // namespace grappler
}  // namespace tensorflow