1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2 
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6 
7     http://www.apache.org/licenses/LICENSE-2.0
8 
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/grappler_item_builder.h"
16 
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <vector>
20 
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_factory.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/graph_optimizer.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/function.pb.h"
29 #include "tensorflow/core/framework/graph_def_util.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/framework/variable.pb.h"
36 #include "tensorflow/core/framework/versions.pb.h"
37 #include "tensorflow/core/graph/graph_constructor.h"
38 #include "tensorflow/core/grappler/inputs/utils.h"
39 #include "tensorflow/core/grappler/op_types.h"
40 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
41 #include "tensorflow/core/grappler/utils.h"
42 #include "tensorflow/core/lib/gtl/map_util.h"
43 #include "tensorflow/core/lib/io/path.h"
44 #include "tensorflow/core/platform/protobuf_internal.h"
45 #include "tensorflow/core/protobuf/meta_graph.pb.h"
46 #include "tensorflow/core/protobuf/saver.pb.h"
47 #include "tensorflow/core/public/session_options.h"
48 
49 namespace tensorflow {
50 namespace grappler {
51 
52 namespace {
53 
InitializeTensor(DataType type,Tensor * tensor)54 void InitializeTensor(DataType type, Tensor* tensor) {
55   const int period = 7;
56   if (type == DT_FLOAT) {
57     auto flat = tensor->flat<float>();
58     // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
59     for (int i = 0; i < flat.size(); i++) {
60       flat(i) = static_cast<float>(i % period) / 10.0f;
61     }
62   } else if (type == DT_INT64) {
63     auto flat = tensor->flat<int64>();
64     // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
65     for (int i = 0; i < flat.size(); i++) {
66       flat(i) = i % period;
67     }
68   } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) {
69     // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to
70     // is_simple_type<> in tensorflow/core/framework/type_traits.h, and
71     // Allocator will run non-trivial constructor/destructor for a Tensor with
72     // one of these types, so we should not memset its buffer.
73     memset(const_cast<char*>(tensor->tensor_data().data()), 0,
74            tensor->tensor_data().size());
75   }
76 }
77 
78 // Optimize the graph def (including function inlining and other optimizations).
79 // This is a temporary change that optimizes the graph in context of a single
80 // gpu machine. Down the line, we may want to make grappler_item_builder aware
81 // of the cluster type (E.g: single cpu, multiple gpu, etc)  being simulated in
82 // order to get the correct session options and environment, and performing the
83 // correct optimizations.
OptimizeGraph(const GraphDef & graph_def_arg,GraphDef * output_graph_def,const ItemConfig & cfg)84 Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
85                      const ItemConfig& cfg) {
86   if (!cfg.apply_optimizations && !cfg.erase_noinline_attributes) {
87     return Status::OK();
88   }
89 
90   // Create a session option for a single GPU device.
91   SessionOptions options;
92 
93   // Make a local copy of graph def, because we need to change some things.
94   GraphDef graph_def(graph_def_arg);
95 
96   if (cfg.erase_noinline_attributes) {
97     // TF optimizer doesn't inline functions with "_noinline" attribute,
98     // so let's go over the function library and erase it.
99     for (auto& func : *graph_def.mutable_library()->mutable_function()) {
100       func.mutable_attr()->erase("_noinline");
101     }
102   }
103 
104   // Instantiate all variables for function library runtime creation.
105   std::vector<std::unique_ptr<Device>> devices;
106   // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
107   // with dummy session config, which will conflict with user defined options
108   // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
109   // only devices.
110   DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
111   TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
112       options, "/job:localhost/replica:0/task:0", &devices));
113   Device* cpu_device = devices[0].get();
114   std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices)));
115   FunctionLibraryDefinition function_library(OpRegistry::Global(),
116                                              graph_def.library());
117   Env* env = Env::Default();
118 
119   // Optimizer options: L1 and inlining. L1 is default.
120   OptimizerOptions* optimizer_opts =
121       options.config.mutable_graph_options()->mutable_optimizer_options();
122   if (cfg.apply_optimizations) {
123     optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L1);
124   } else {
125     optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L0);
126   }
127 
128   // Create the function library runtime.
129   std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
130       new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
131                                         graph_def.versions().producer(),
132                                         &function_library, *optimizer_opts));
133   FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());
134 
135   // Create the GraphOptimizer to optimize the graph def.
136   GraphConstructorOptions graph_ctor_opts;
137   graph_ctor_opts.allow_internal_ops = true;
138   graph_ctor_opts.expect_device_spec = false;
139   std::unique_ptr<Graph> graphptr(new Graph(function_library));
140 
141   TF_RETURN_IF_ERROR(
142       ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get()));
143 
144   // Optimize the graph.
145   ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
146   optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
147   graphptr->ToGraphDef(output_graph_def);
148 
149   // The default values of attributes might have been stripped by the optimizer.
150   // Add them back.
151   return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
152                                    0, true);
153 }
154 
155 // Applies the same graph pruning logic to the graph as Session.Run in TF.
156 // If the returned status is not OK, item state may be inconsistent.
PruneGraph(GrapplerItem * item)157 Status PruneGraph(GrapplerItem* item) {
158   ModelPruner pruner;
159   GraphDef pruned_graph;
160   Cluster* cluster = nullptr;  // ModelPruner doesn't check cluster.
161   TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
162   item->graph = std::move(pruned_graph);
163   return Status::OK();
164 }
165 
166 // Replace any unknown dimensions in a shape with
167 // cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
ReplaceUnknownShapeDim(const ItemConfig & cfg,const TensorShapeProto & shape_pb_in,TensorShapeProto * shape_pb_out,TensorShape * shape_out)168 Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
169                               const TensorShapeProto& shape_pb_in,
170                               TensorShapeProto* shape_pb_out,
171                               TensorShape* shape_out) {
172   std::vector<int32> dims;
173   for (const auto& dim_proto : shape_pb_in.dim()) {
174     if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
175         dim_proto.size() == -1) {
176       dims.push_back(cfg.placeholder_unknown_output_shape_dim);
177       shape_pb_out->add_dim()->set_size(
178           cfg.placeholder_unknown_output_shape_dim);
179     } else {
180       dims.push_back(std::max<int32>(1, dim_proto.size()));
181       shape_pb_out->add_dim()->set_size(dim_proto.size());
182     }
183   }
184   return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
185 }
186 
187 }  // namespace
188 
189 // static
GrapplerItemFromMetaGraphDef(const string & id,const MetaGraphDef & meta_graph,const ItemConfig & cfg)190 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
191     const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
192   if (id.empty()) {
193     LOG(ERROR) << "id must be non-empty.";
194     return nullptr;
195   }
196   std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
197   new_item->id = id;
198   new_item->graph = meta_graph.graph_def();
199 
200   // Fill in feed nodes from config, if any provided.
201   for (const auto& feed_node : cfg.feed_nodes) {
202     const string feed_name = NodeName(feed_node);
203     new_item->feed.emplace_back(feed_name, Tensor());
204   }
205   for (const auto& fetch_node : cfg.fetch_nodes) {
206     new_item->fetch.emplace_back(NodeName(fetch_node));
207   }
208 
209   // Attempt to detect the fetch node(s) if they were not set explicitly.
210   if (new_item->fetch.empty() &&
211       meta_graph.collection_def().count("train_op") > 0) {
212     const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
213     if (nodes.has_node_list()) {
214       for (const auto& node : nodes.node_list().value()) {
215         new_item->fetch.push_back(NodeName(node));
216       }
217     }
218   }
219 
220   // Detect feed and fetch nodes from signature defs. Signatures may share same
221   // inputs or outputs.
222   std::unordered_set<string> signature_feed_nodes;
223   std::unordered_set<string> signature_fetch_nodes;
224   for (const auto& name_and_signature : meta_graph.signature_def()) {
225     for (const auto& name_and_input : name_and_signature.second.inputs()) {
226       const TensorInfo& input = name_and_input.second;
227       if (input.has_coo_sparse()) {
228         // Define the shapes following the comment of CooSparse.
229         // TODO(yuefengz): we probably want to use different dim values for the
230         // three tensors of a SparseTensor.
231         int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
232         TensorShape shape_1d({dim});
233         TensorShape shape_2d({dim, dim});
234 
235         if (gtl::InsertIfNotPresent(
236                 &signature_feed_nodes,
237                 NodeName(input.coo_sparse().values_tensor_name()))) {
238           Tensor value_tensor(input.dtype(), shape_1d);
239           InitializeTensor(input.dtype(), &value_tensor);
240           new_item->feed.emplace_back(
241               NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
242         }
243         if (gtl::InsertIfNotPresent(
244                 &signature_feed_nodes,
245                 NodeName(input.coo_sparse().indices_tensor_name()))) {
246           Tensor indices_tensor(DT_INT64, shape_2d);
247           InitializeTensor(input.dtype(), &indices_tensor);
248           new_item->feed.emplace_back(
249               NodeName(input.coo_sparse().indices_tensor_name()),
250               indices_tensor);
251         }
252         if (gtl::InsertIfNotPresent(
253                 &signature_feed_nodes,
254                 NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
255           Tensor dense_shape_tensor(DT_INT64, shape_1d);
256           InitializeTensor(input.dtype(), &dense_shape_tensor);
257           new_item->feed.emplace_back(
258               NodeName(input.coo_sparse().dense_shape_tensor_name()),
259               dense_shape_tensor);
260         }
261       } else {
262         if (gtl::InsertIfNotPresent(&signature_feed_nodes,
263                                     NodeName(input.name()))) {
264           TensorShape shape;
265           TensorShapeProto shape_proto;
266           Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
267                                             &shape_proto, &shape);
268           if (!s.ok()) {
269             LOG(ERROR) << "Invalid shape for signature input " << input.name()
270                        << ": " << s << ", skipping this input";
271             return nullptr;
272           }
273 
274           Tensor fake_input(input.dtype(), shape);
275           InitializeTensor(input.dtype(), &fake_input);
276           new_item->feed.emplace_back(NodeName(input.name()), fake_input);
277         }
278       }
279     }
280     for (const auto& name_and_output : name_and_signature.second.outputs()) {
281       const TensorInfo& output = name_and_output.second;
282       if (output.has_coo_sparse()) {
283         if (gtl::InsertIfNotPresent(
284                 &signature_fetch_nodes,
285                 NodeName(output.coo_sparse().values_tensor_name()))) {
286           new_item->fetch.push_back(
287               NodeName(output.coo_sparse().values_tensor_name()));
288         }
289         if (gtl::InsertIfNotPresent(
290                 &signature_fetch_nodes,
291                 NodeName(output.coo_sparse().indices_tensor_name()))) {
292           new_item->fetch.push_back(
293               NodeName(output.coo_sparse().indices_tensor_name()));
294         }
295         if (gtl::InsertIfNotPresent(
296                 &signature_fetch_nodes,
297                 NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
298           new_item->fetch.push_back(
299               NodeName(output.coo_sparse().dense_shape_tensor_name()));
300         }
301       } else {
302         if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
303                                     NodeName(output.name()))) {
304           new_item->fetch.push_back(NodeName(output.name()));
305         }
306       }
307     }
308   }
309 
310   for (const auto& feed : new_item->feed) {
311     if (feed.first.empty()) {
312       LOG(ERROR) << "Invalid feed node name skipping this input";
313       return nullptr;
314     } else {
315       VLOG(1) << "Will use feed node " << feed.first;
316     }
317   }
318 
319   for (const auto& fetch : new_item->fetch) {
320     if (fetch.empty()) {
321       LOG(ERROR) << "Invalid fetch node name skipping this input";
322       return nullptr;
323     } else {
324       VLOG(1) << "Will use fetch node " << fetch;
325     }
326   }
327 
328   if (new_item->fetch.empty()) {
329     LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
330     return nullptr;
331   }
332 
333   // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
334   // The reason why they are difficult to handle is because they may not intend
335   // to initialize all variables that are required to run fetch nodes. We may
336   // have to run restore op first.
337 
338   // Try to find initializers from variables and tables as init ops.
339   for (const string& var_collection :
340        {"variables", "local_variables", "model_variables",
341         "trainable_variables"}) {
342     if (meta_graph.collection_def().count(var_collection) == 0) {
343       continue;
344     }
345     const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
346     for (const auto& raw_var : vars.bytes_list().value()) {
347       VariableDef var;
348       var.ParseFromString(raw_var);
349       if (!var.initializer_name().empty()) {
350         new_item->init_ops.push_back(NodeName(var.initializer_name()));
351       }
352     }
353   }
354 
355   if (meta_graph.collection_def().count("table_initializer") > 0) {
356     const CollectionDef& inits =
357         meta_graph.collection_def().at("table_initializer");
358     if (inits.has_node_list()) {
359       for (const auto& node : inits.node_list().value()) {
360         new_item->init_ops.push_back(NodeName(node));
361         // Tables are initialized from files, which can take a long time. Add
362         // 30 minutes to the initialization time for each table to avoid
363         // timing out.
364         // TODO(bsteiner): adjust the timeout based on the file size.
365         new_item->expected_init_time += 30 * 60;
366       }
367     }
368   }
369 
370   // We keep the mapping from asset node to asset files. This should have been
371   // used as feed but since asset node is usually a constant node, we will fill
372   // the values of these constant nodes with their actual asset file paths.
373   std::unordered_map<string, string> asset_node_to_value;
374 
375   // Assets file may have changed their directory, we assemble their new paths
376   // if assets_directory_override is set. We also make sure we still can
377   // access these asset files.
378   if (!cfg.assets_directory_override.empty()) {
379     if (meta_graph.collection_def().count("saved_model_assets") > 0) {
380       const CollectionDef& collection =
381           meta_graph.collection_def().at("saved_model_assets");
382       const auto& any_assets = collection.any_list().value();
383       for (const auto& any_asset : any_assets) {
384         AssetFileDef asset_file_def;
385         if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
386                  .ok()) {
387           LOG(ERROR) << "Failed to parse AssetFile.";
388           continue;
389         }
390         string asset_filepath = io::JoinPath(cfg.assets_directory_override,
391                                              asset_file_def.filename());
392         if (!FilesExist({asset_filepath}, nullptr)) {
393           LOG(ERROR) << "Can't access one or more of the asset files "
394                      << asset_filepath << ", skipping this input";
395           return nullptr;
396         }
397         asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
398             asset_filepath;
399       }
400     }
401   } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
402     const CollectionDef& file_paths =
403         meta_graph.collection_def().at("asset_filepaths");
404     std::vector<string> paths;
405     for (const auto& raw_path : file_paths.bytes_list().value()) {
406       paths.push_back(raw_path);
407     }
408     if (!FilesExist(paths, nullptr)) {
409       LOG(ERROR) << "Can't access one or more of the asset files, skipping "
410                     "this input";
411       return nullptr;
412     }
413   }
414 
415   if (meta_graph.collection_def().count("queue_runners") > 0) {
416     const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
417     for (const auto& raw : vars.bytes_list().value()) {
418       QueueRunnerDef queue_runner;
419       if (!queue_runner.ParseFromString(raw)) {
420         LOG(ERROR) << "Could not parse queue_runners, skipping this input";
421         return nullptr;
422       }
423       if (queue_runner.cancel_op_name().empty()) {
424         LOG(ERROR) << "Queue without a cancel op, skipping this input";
425         return nullptr;
426       }
427       new_item->queue_runners.push_back(queue_runner);
428     }
429   }
430 
431   // Add each node referenced in a collection to the list of nodes to keep.
432   for (const auto& col : meta_graph.collection_def()) {
433     const CollectionDef& collection = col.second;
434     for (const string& node : collection.node_list().value()) {
435       new_item->keep_ops.push_back(NodeName(node));
436     }
437   }
438 
439   for (auto& node : *new_item->graph.mutable_node()) {
440     if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
441       if (node.attr().count("dtype") == 0) {
442         LOG(ERROR) << "Unknown type for placeholder " << node.name()
443                    << ", skipping this input";
444         return nullptr;
445       }
446       DataType type = node.attr().at("dtype").type();
447 
448       if (node.attr().count("shape") == 0) {
449         LOG(INFO) << "Unknown shape for placeholder " << node.name()
450                   << ", skipping this input";
451         return nullptr;
452       }
453 
454       // Replace all unknown dimensions in the placeholder's tensorshape proto
455       // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
456       // from it. We do this because in newer protos, the input placeholder
457       // shape is not empty if the shape is partially defined.
458       TensorShape shape;
459       TensorShapeProto shape_proto;
460       Status make_shape_status = ReplaceUnknownShapeDim(
461           cfg, node.attr().at("shape").shape(), &shape_proto, &shape);
462       if (!make_shape_status.ok()) {
463         LOG(ERROR) << "Invalid shape for placeholder " << node.name() << ": "
464                    << make_shape_status << ", skipping this input";
465         return nullptr;
466       }
467 
468       // Some placeholder nodes have a mis-match between the node
469       // attribute "shape" and a different node attribute "_output_shapes".
470       // Specifically, a shape with shape.dims() == 0 could indicate either
471       // a scalar or an unknown shape. In those cases, we check _output_shapes
472       // for additional information.
473       // This case is observed in the bnmt graphs. Have not observed any
474       // cases where there was more than 1 _output_shapes, so limit it
475       // to cases where there is only 1 _output_shapes.
476       // We only do this if cfg.placeholder_unknown_output_shape_dim has
477       // been set to avoid crashing non-BNMT graphs.
478       if ((cfg.placeholder_unknown_output_shape_dim >= 0) &&
479           (shape.dims() == 0) && (node.attr().count("_output_shapes") == 1)) {
480         const auto& output_shapes =
481             node.attr().at("_output_shapes").list().shape(0);
482 
483         if (output_shapes.dim_size() != 0) {
484           shape.Clear();
485           shape_proto.clear_dim();
486 
487           for (const auto& dim : output_shapes.dim()) {
488             auto size = dim.size();
489             if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
490             shape.AddDim(size);
491             shape_proto.add_dim()->set_size(size);
492           }
493         }
494       }
495 
496       Tensor fake_input(type, shape);
497       InitializeTensor(type, &fake_input);
498 
499       if (cfg.feed_nodes.empty()) {
500         // No specific feed nodes were given. Assume all placeholders are fed.
501         if (signature_feed_nodes.count(node.name()) == 0) {
502           new_item->feed.emplace_back(node.name(), fake_input);
503         }
504       } else if (cfg.feed_nodes.count(node.name()) > 0) {
505         // If specific feed nodes were given, only update their tensors.
506         auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
507                           [&node](std::pair<string, Tensor>& f) {
508                             return f.first == node.name();
509                           });
510         QCHECK(it != new_item->feed.end());
511         it->second = fake_input;
512       }
513 
514       // Set the shape of the node in the graph. This is needed for statically
515       // inferring shapes and is a no-op when dynamically inferring shapes as
516       // the Placeholder shape will match the shape passed from new_item->feed.
517       *(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto;
518     } else if (IsConstant(node)) {
519       auto it = asset_node_to_value.find(node.name());
520       if (it != asset_node_to_value.end()) {
521         auto iter = node.mutable_attr()->find("value");
522         if (iter == node.attr().end()) {
523           LOG(ERROR) << "Value attribute expected in const op for asset files";
524           return nullptr;
525         }
526         if (!iter->second.has_tensor() ||
527             iter->second.tensor().string_val_size() != 1) {
528           LOG(INFO) << "Unexpected AttrValue proto: "
529                     << iter->second.DebugString();
530           return nullptr;
531         }
532         LOG(INFO) << "Using asset file " << it->second << " for node "
533                   << node.name();
534         *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
535       }
536     }
537 
538     // Erase the recorded result of any previous shape inference to start again
539     // from scratch.
540     node.mutable_attr()->erase("_output_shapes");
541 
542     // Delete user specified placement if requested.
543     if (cfg.ignore_user_placement) {
544       node.clear_device();
545     }
546     // Delete colocation constraints if requested.
547     if (cfg.ignore_colocation) {
548       auto attr = node.mutable_attr();
549       auto it = attr->find("_class");
550       if (it != attr->end()) {
551         attr->erase(it);
552       }
553     }
554   }
555 
556   if (meta_graph.collection_def().count("savers") > 0) {
557     const CollectionDef& savers = meta_graph.collection_def().at("savers");
558     for (const auto& raw : savers.bytes_list().value()) {
559       SaverDef saver;
560       // Skip bad savers since we don't need saves/restores to be able to run a
561       // graph.
562       if (!saver.ParseFromString(raw)) {
563         continue;
564       }
565       if (saver.filename_tensor_name().empty()) {
566         continue;
567       }
568       new_item->save_op = saver.save_tensor_name();
569       new_item->restore_op = saver.restore_op_name();
570       new_item->save_restore_loc_tensor = saver.filename_tensor_name();
571       // Only use the first saver since it's not clear what to do if there's
572       // more than one.
573       break;
574     }
575   } else {
576     const SaverDef& saver = meta_graph.saver_def();
577     new_item->save_op = saver.save_tensor_name();
578     new_item->restore_op = saver.restore_op_name();
579     new_item->save_restore_loc_tensor = saver.filename_tensor_name();
580   }
581 
582   // Instantiate all the missing attributes with their default values.
583   Status attr_status = AddDefaultAttrsToGraphDef(
584       &new_item->graph,
585       FunctionLibraryDefinition(OpRegistry::Global(),
586                                 new_item->graph.library()),
587       0, true);
588   if (!attr_status.ok()) {
589     LOG(ERROR) << "Failed to instantiate default attribute values: "
590                << attr_status.error_message();
591     return nullptr;
592   }
593 
594   // Optimize the graph (function inlining, l1 optimizations, etc).
595   VLOG(1) << "Number of nodes in graph before OptimizeGraph: "
596           << new_item->graph.node_size();
597   Status optimize_status =
598       OptimizeGraph(new_item->graph, &new_item->graph, cfg);
599   if (!optimize_status.ok()) {
600     LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
601     return nullptr;
602   }
603   VLOG(1) << "Number of nodes in graph after OptimizeGraph: "
604           << new_item->graph.node_size();
605 
606   if (cfg.prune_graph) {
607     VLOG(1) << "Pruning graph...";
608     auto status = PruneGraph(new_item.get());
609     if (!status.ok()) {
610       LOG(ERROR) << "Pruning failed: " << status.error_message();
611       return nullptr;
612     }
613     VLOG(1) << "Number of nodes in graph after pruning: "
614             << new_item->graph.node_size();
615   }
616 
617   // Validate feed, fetch and init nodes
618   std::unordered_set<string> nodes;
619   for (const auto& node : new_item->graph.node()) {
620     nodes.insert(node.name());
621   }
622   for (const auto& feed : new_item->feed) {
623     if (nodes.find(feed.first) == nodes.end()) {
624       LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
625       return nullptr;
626     }
627   }
628   for (const auto& fetch : new_item->fetch) {
629     if (nodes.find(fetch) == nodes.end()) {
630       LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
631       return nullptr;
632     }
633   }
634   for (const auto& init : new_item->init_ops) {
635     if (nodes.find(init) == nodes.end()) {
636       LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
637       return nullptr;
638     }
639   }
640   return new_item;
641 }
642 
GrapplerItemFromMetaGraphDefFile(const string & id,const string & meta_graph_file,const ItemConfig & cfg)643 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
644     const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
645   MetaGraphDef meta_graph;
646   if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
647     return nullptr;
648   }
649   return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
650 }
651 
652 }  // end namespace grappler
653 }  // end namespace tensorflow
654