1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3 Licensed under the Apache License, Version 2.0 (the "License");
4 you may not use this file except in compliance with the License.
5 You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9 Unless required by applicable law or agreed to in writing, software
10 distributed under the License is distributed on an "AS IS" BASIS,
11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 See the License for the specific language governing permissions and
13 limitations under the License.
14 ==============================================================================*/
15 #include "tensorflow/core/grappler/grappler_item_builder.h"
16
17 #include <unordered_map>
18 #include <unordered_set>
19 #include <vector>
20
21 #include "tensorflow/core/common_runtime/device.h"
22 #include "tensorflow/core/common_runtime/device_factory.h"
23 #include "tensorflow/core/common_runtime/device_mgr.h"
24 #include "tensorflow/core/common_runtime/function.h"
25 #include "tensorflow/core/common_runtime/graph_optimizer.h"
26 #include "tensorflow/core/framework/attr_value.pb.h"
27 #include "tensorflow/core/framework/function.h"
28 #include "tensorflow/core/framework/function.pb.h"
29 #include "tensorflow/core/framework/graph_def_util.h"
30 #include "tensorflow/core/framework/node_def.pb.h"
31 #include "tensorflow/core/framework/op.h"
32 #include "tensorflow/core/framework/tensor.pb.h"
33 #include "tensorflow/core/framework/tensor_shape.pb.h"
34 #include "tensorflow/core/framework/types.pb.h"
35 #include "tensorflow/core/framework/variable.pb.h"
36 #include "tensorflow/core/framework/versions.pb.h"
37 #include "tensorflow/core/graph/graph_constructor.h"
38 #include "tensorflow/core/grappler/inputs/utils.h"
39 #include "tensorflow/core/grappler/op_types.h"
40 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
41 #include "tensorflow/core/grappler/utils.h"
42 #include "tensorflow/core/lib/gtl/map_util.h"
43 #include "tensorflow/core/lib/io/path.h"
44 #include "tensorflow/core/platform/protobuf_internal.h"
45 #include "tensorflow/core/protobuf/meta_graph.pb.h"
46 #include "tensorflow/core/protobuf/saver.pb.h"
47 #include "tensorflow/core/public/session_options.h"
48
49 namespace tensorflow {
50 namespace grappler {
51
52 namespace {
53
InitializeTensor(DataType type,Tensor * tensor)54 void InitializeTensor(DataType type, Tensor* tensor) {
55 const int period = 7;
56 if (type == DT_FLOAT) {
57 auto flat = tensor->flat<float>();
58 // Populate numbers 0, 0.1, 0.2, ..., 0.5, 0.6, 0, 0.1, 0.2, ...
59 for (int i = 0; i < flat.size(); i++) {
60 flat(i) = static_cast<float>(i % period) / 10.0f;
61 }
62 } else if (type == DT_INT64) {
63 auto flat = tensor->flat<int64>();
64 // Populate numbers 0, 1, 2, ..., 5, 6, 0, 1, 2, ...
65 for (int i = 0; i < flat.size(); i++) {
66 flat(i) = i % period;
67 }
68 } else if (type != DT_STRING && type != DT_RESOURCE && type != DT_VARIANT) {
69 // DT_STRING, DT_RESOURCE and DT_VARIANT are not simple types according to
70 // is_simple_type<> in tensorflow/core/framework/type_traits.h, and
71 // Allocator will run non-trivial constructor/destructor for a Tensor with
72 // one of these types, so we should not memset its buffer.
73 memset(const_cast<char*>(tensor->tensor_data().data()), 0,
74 tensor->tensor_data().size());
75 }
76 }
77
// Optimize the graph def (including function inlining and other optimizations).
// This is a temporary change that optimizes the graph in context of a single
// gpu machine. Down the line, we may want to make grappler_item_builder aware
// of the cluster type (E.g: single cpu, multiple gpu, etc) being simulated in
// order to get the correct session options and environment, and performing the
// correct optimizations.
//
// On success writes the optimized graph to *output_graph_def and restores any
// default attribute values the optimizer stripped. Returns an error if device
// creation, graph construction, or default-attribute restoration fails.
Status OptimizeGraph(const GraphDef& graph_def_arg, GraphDef* output_graph_def,
                     const ItemConfig& cfg) {
  // Fast path: nothing was requested, so *output_graph_def is left untouched.
  if (!cfg.apply_optimizations && !cfg.erase_noinline_attributes) {
    return Status::OK();
  }

  // Create a session option for a single GPU device.
  SessionOptions options;

  // Make a local copy of graph def, because we need to change some things.
  GraphDef graph_def(graph_def_arg);

  if (cfg.erase_noinline_attributes) {
    // TF optimizer doesn't inline functions with "_noinline" attribute,
    // so let's go over the function library and erase it.
    for (auto& func : *graph_def.mutable_library()->mutable_function()) {
      func.mutable_attr()->erase("_noinline");
    }
  }

  // Instantiate all variables for function library runtime creation.
  std::vector<std::unique_ptr<Device>> devices;
  // Only CPU device is used so instead of calling DeviceFactory::AddDevices()
  // with dummy session config, which will conflict with user defined options
  // and create unwanted devices, call cpu_factory->CreateDevices() to get CPU
  // only devices.
  DeviceFactory* cpu_factory = DeviceFactory::GetFactory("CPU");
  TF_RETURN_IF_ERROR(cpu_factory->CreateDevices(
      options, "/job:localhost/replica:0/task:0", &devices));
  // Keep a raw pointer to the first CPU device before the DeviceMgr takes
  // ownership of the whole vector; dvc_mgr must outlive cpu_device and pflr.
  Device* cpu_device = devices[0].get();
  std::unique_ptr<DeviceMgr> dvc_mgr(new DeviceMgr(std::move(devices)));
  FunctionLibraryDefinition function_library(OpRegistry::Global(),
                                             graph_def.library());
  Env* env = Env::Default();

  // Optimizer options: L1 and inlining. L1 is default.
  OptimizerOptions* optimizer_opts =
      options.config.mutable_graph_options()->mutable_optimizer_options();
  if (cfg.apply_optimizations) {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L1);
  } else {
    optimizer_opts->set_opt_level(::tensorflow::OptimizerOptions_Level_L0);
  }

  // Create the function library runtime.
  std::unique_ptr<ProcessFunctionLibraryRuntime> pflr(
      new ProcessFunctionLibraryRuntime(dvc_mgr.get(), env,
                                        graph_def.versions().producer(),
                                        &function_library, *optimizer_opts));
  FunctionLibraryRuntime* flr = pflr->GetFLR(cpu_device->name());

  // Create the GraphOptimizer to optimize the graph def.
  GraphConstructorOptions graph_ctor_opts;
  graph_ctor_opts.allow_internal_ops = true;
  graph_ctor_opts.expect_device_spec = false;
  std::unique_ptr<Graph> graphptr(new Graph(function_library));

  TF_RETURN_IF_ERROR(
      ConvertGraphDefToGraph(graph_ctor_opts, graph_def, graphptr.get()));

  // Optimize the graph. Optimize() may replace graphptr with a new Graph.
  ::tensorflow::GraphOptimizer optimizer(*optimizer_opts);
  optimizer.Optimize(flr, env, cpu_device, &graphptr, /*shape_map=*/nullptr);
  graphptr->ToGraphDef(output_graph_def);

  // The default values of attributes might have been stripped by the optimizer.
  // Add them back.
  return AddDefaultAttrsToGraphDef(output_graph_def, *graphptr->op_registry(),
                                   0, true);
}
154
155 // Applies the same graph pruning logic to the graph as Session.Run in TF.
156 // If the returned status is not OK, item state may be inconsistent.
PruneGraph(GrapplerItem * item)157 Status PruneGraph(GrapplerItem* item) {
158 ModelPruner pruner;
159 GraphDef pruned_graph;
160 Cluster* cluster = nullptr; // ModelPruner doesn't check cluster.
161 TF_RETURN_IF_ERROR(pruner.Optimize(cluster, *item, &pruned_graph));
162 item->graph = std::move(pruned_graph);
163 return Status::OK();
164 }
165
166 // Replace any unknown dimensions in a shape with
167 // cfg.placeholder_unknown_output_shape_dim if it is no less than 0.
ReplaceUnknownShapeDim(const ItemConfig & cfg,const TensorShapeProto & shape_pb_in,TensorShapeProto * shape_pb_out,TensorShape * shape_out)168 Status ReplaceUnknownShapeDim(const ItemConfig& cfg,
169 const TensorShapeProto& shape_pb_in,
170 TensorShapeProto* shape_pb_out,
171 TensorShape* shape_out) {
172 std::vector<int32> dims;
173 for (const auto& dim_proto : shape_pb_in.dim()) {
174 if (cfg.placeholder_unknown_output_shape_dim >= 0 &&
175 dim_proto.size() == -1) {
176 dims.push_back(cfg.placeholder_unknown_output_shape_dim);
177 shape_pb_out->add_dim()->set_size(
178 cfg.placeholder_unknown_output_shape_dim);
179 } else {
180 dims.push_back(std::max<int32>(1, dim_proto.size()));
181 shape_pb_out->add_dim()->set_size(dim_proto.size());
182 }
183 }
184 return TensorShapeUtils::MakeShape(dims.data(), dims.size(), shape_out);
185 }
186
187 } // namespace
188
189 // static
GrapplerItemFromMetaGraphDef(const string & id,const MetaGraphDef & meta_graph,const ItemConfig & cfg)190 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDef(
191 const string& id, const MetaGraphDef& meta_graph, const ItemConfig& cfg) {
192 if (id.empty()) {
193 LOG(ERROR) << "id must be non-empty.";
194 return nullptr;
195 }
196 std::unique_ptr<GrapplerItem> new_item(new GrapplerItem());
197 new_item->id = id;
198 new_item->graph = meta_graph.graph_def();
199
200 // Fill in feed nodes from config, if any provided.
201 for (const auto& feed_node : cfg.feed_nodes) {
202 const string feed_name = NodeName(feed_node);
203 new_item->feed.emplace_back(feed_name, Tensor());
204 }
205 for (const auto& fetch_node : cfg.fetch_nodes) {
206 new_item->fetch.emplace_back(NodeName(fetch_node));
207 }
208
209 // Attempt to detect the fetch node(s) if they were not set explicitly.
210 if (new_item->fetch.empty() &&
211 meta_graph.collection_def().count("train_op") > 0) {
212 const CollectionDef& nodes = meta_graph.collection_def().at("train_op");
213 if (nodes.has_node_list()) {
214 for (const auto& node : nodes.node_list().value()) {
215 new_item->fetch.push_back(NodeName(node));
216 }
217 }
218 }
219
220 // Detect feed and fetch nodes from signature defs. Signatures may share same
221 // inputs or outputs.
222 std::unordered_set<string> signature_feed_nodes;
223 std::unordered_set<string> signature_fetch_nodes;
224 for (const auto& name_and_signature : meta_graph.signature_def()) {
225 for (const auto& name_and_input : name_and_signature.second.inputs()) {
226 const TensorInfo& input = name_and_input.second;
227 if (input.has_coo_sparse()) {
228 // Define the shapes following the comment of CooSparse.
229 // TODO(yuefengz): we probably want to use different dim values for the
230 // three tensors of a SparseTensor.
231 int64 dim = std::max(1, cfg.placeholder_unknown_output_shape_dim);
232 TensorShape shape_1d({dim});
233 TensorShape shape_2d({dim, dim});
234
235 if (gtl::InsertIfNotPresent(
236 &signature_feed_nodes,
237 NodeName(input.coo_sparse().values_tensor_name()))) {
238 Tensor value_tensor(input.dtype(), shape_1d);
239 InitializeTensor(input.dtype(), &value_tensor);
240 new_item->feed.emplace_back(
241 NodeName(input.coo_sparse().values_tensor_name()), value_tensor);
242 }
243 if (gtl::InsertIfNotPresent(
244 &signature_feed_nodes,
245 NodeName(input.coo_sparse().indices_tensor_name()))) {
246 Tensor indices_tensor(DT_INT64, shape_2d);
247 InitializeTensor(input.dtype(), &indices_tensor);
248 new_item->feed.emplace_back(
249 NodeName(input.coo_sparse().indices_tensor_name()),
250 indices_tensor);
251 }
252 if (gtl::InsertIfNotPresent(
253 &signature_feed_nodes,
254 NodeName(input.coo_sparse().dense_shape_tensor_name()))) {
255 Tensor dense_shape_tensor(DT_INT64, shape_1d);
256 InitializeTensor(input.dtype(), &dense_shape_tensor);
257 new_item->feed.emplace_back(
258 NodeName(input.coo_sparse().dense_shape_tensor_name()),
259 dense_shape_tensor);
260 }
261 } else {
262 if (gtl::InsertIfNotPresent(&signature_feed_nodes,
263 NodeName(input.name()))) {
264 TensorShape shape;
265 TensorShapeProto shape_proto;
266 Status s = ReplaceUnknownShapeDim(cfg, input.tensor_shape(),
267 &shape_proto, &shape);
268 if (!s.ok()) {
269 LOG(ERROR) << "Invalid shape for signature input " << input.name()
270 << ": " << s << ", skipping this input";
271 return nullptr;
272 }
273
274 Tensor fake_input(input.dtype(), shape);
275 InitializeTensor(input.dtype(), &fake_input);
276 new_item->feed.emplace_back(NodeName(input.name()), fake_input);
277 }
278 }
279 }
280 for (const auto& name_and_output : name_and_signature.second.outputs()) {
281 const TensorInfo& output = name_and_output.second;
282 if (output.has_coo_sparse()) {
283 if (gtl::InsertIfNotPresent(
284 &signature_fetch_nodes,
285 NodeName(output.coo_sparse().values_tensor_name()))) {
286 new_item->fetch.push_back(
287 NodeName(output.coo_sparse().values_tensor_name()));
288 }
289 if (gtl::InsertIfNotPresent(
290 &signature_fetch_nodes,
291 NodeName(output.coo_sparse().indices_tensor_name()))) {
292 new_item->fetch.push_back(
293 NodeName(output.coo_sparse().indices_tensor_name()));
294 }
295 if (gtl::InsertIfNotPresent(
296 &signature_fetch_nodes,
297 NodeName(output.coo_sparse().dense_shape_tensor_name()))) {
298 new_item->fetch.push_back(
299 NodeName(output.coo_sparse().dense_shape_tensor_name()));
300 }
301 } else {
302 if (gtl::InsertIfNotPresent(&signature_fetch_nodes,
303 NodeName(output.name()))) {
304 new_item->fetch.push_back(NodeName(output.name()));
305 }
306 }
307 }
308 }
309
310 for (const auto& feed : new_item->feed) {
311 if (feed.first.empty()) {
312 LOG(ERROR) << "Invalid feed node name skipping this input";
313 return nullptr;
314 } else {
315 VLOG(1) << "Will use feed node " << feed.first;
316 }
317 }
318
319 for (const auto& fetch : new_item->fetch) {
320 if (fetch.empty()) {
321 LOG(ERROR) << "Invalid fetch node name skipping this input";
322 return nullptr;
323 } else {
324 VLOG(1) << "Will use fetch node " << fetch;
325 }
326 }
327
328 if (new_item->fetch.empty()) {
329 LOG(ERROR) << "Failed to detect the fetch node(s), skipping this input";
330 return nullptr;
331 }
332
333 // TODO(yuefengz): consider handling saved_model_main_op and legacy_init_op.
334 // The reason why they are difficult to handle is because they may not intend
335 // to initialize all variables that are required to run fetch nodes. We may
336 // have to run restore op first.
337
338 // Try to find initializers from variables and tables as init ops.
339 for (const string& var_collection :
340 {"variables", "local_variables", "model_variables",
341 "trainable_variables"}) {
342 if (meta_graph.collection_def().count(var_collection) == 0) {
343 continue;
344 }
345 const CollectionDef& vars = meta_graph.collection_def().at(var_collection);
346 for (const auto& raw_var : vars.bytes_list().value()) {
347 VariableDef var;
348 var.ParseFromString(raw_var);
349 if (!var.initializer_name().empty()) {
350 new_item->init_ops.push_back(NodeName(var.initializer_name()));
351 }
352 }
353 }
354
355 if (meta_graph.collection_def().count("table_initializer") > 0) {
356 const CollectionDef& inits =
357 meta_graph.collection_def().at("table_initializer");
358 if (inits.has_node_list()) {
359 for (const auto& node : inits.node_list().value()) {
360 new_item->init_ops.push_back(NodeName(node));
361 // Tables are initialized from files, which can take a long time. Add
362 // 30 minutes to the initialization time for each table to avoid
363 // timing out.
364 // TODO(bsteiner): adjust the timeout based on the file size.
365 new_item->expected_init_time += 30 * 60;
366 }
367 }
368 }
369
370 // We keep the mapping from asset node to asset files. This should have been
371 // used as feed but since asset node is usually a constant node, we will fill
372 // the values of these constant nodes with their actual asset file paths.
373 std::unordered_map<string, string> asset_node_to_value;
374
375 // Assets file may have changed their directory, we assemble their new paths
376 // if assets_directory_override is set. We also make sure we still can
377 // access these asset files.
378 if (!cfg.assets_directory_override.empty()) {
379 if (meta_graph.collection_def().count("saved_model_assets") > 0) {
380 const CollectionDef& collection =
381 meta_graph.collection_def().at("saved_model_assets");
382 const auto& any_assets = collection.any_list().value();
383 for (const auto& any_asset : any_assets) {
384 AssetFileDef asset_file_def;
385 if (!ParseAny(any_asset, &asset_file_def, "tensorflow.AssetFileDef")
386 .ok()) {
387 LOG(ERROR) << "Failed to parse AssetFile.";
388 continue;
389 }
390 string asset_filepath = io::JoinPath(cfg.assets_directory_override,
391 asset_file_def.filename());
392 if (!FilesExist({asset_filepath}, nullptr)) {
393 LOG(ERROR) << "Can't access one or more of the asset files "
394 << asset_filepath << ", skipping this input";
395 return nullptr;
396 }
397 asset_node_to_value[NodeName(asset_file_def.tensor_info().name())] =
398 asset_filepath;
399 }
400 }
401 } else if (meta_graph.collection_def().count("asset_filepaths") > 0) {
402 const CollectionDef& file_paths =
403 meta_graph.collection_def().at("asset_filepaths");
404 std::vector<string> paths;
405 for (const auto& raw_path : file_paths.bytes_list().value()) {
406 paths.push_back(raw_path);
407 }
408 if (!FilesExist(paths, nullptr)) {
409 LOG(ERROR) << "Can't access one or more of the asset files, skipping "
410 "this input";
411 return nullptr;
412 }
413 }
414
415 if (meta_graph.collection_def().count("queue_runners") > 0) {
416 const CollectionDef& vars = meta_graph.collection_def().at("queue_runners");
417 for (const auto& raw : vars.bytes_list().value()) {
418 QueueRunnerDef queue_runner;
419 if (!queue_runner.ParseFromString(raw)) {
420 LOG(ERROR) << "Could not parse queue_runners, skipping this input";
421 return nullptr;
422 }
423 if (queue_runner.cancel_op_name().empty()) {
424 LOG(ERROR) << "Queue without a cancel op, skipping this input";
425 return nullptr;
426 }
427 new_item->queue_runners.push_back(queue_runner);
428 }
429 }
430
431 // Add each node referenced in a collection to the list of nodes to keep.
432 for (const auto& col : meta_graph.collection_def()) {
433 const CollectionDef& collection = col.second;
434 for (const string& node : collection.node_list().value()) {
435 new_item->keep_ops.push_back(NodeName(node));
436 }
437 }
438
439 for (auto& node : *new_item->graph.mutable_node()) {
440 if (IsPlaceholder(node) && node.op() != "PlaceholderWithDefault") {
441 if (node.attr().count("dtype") == 0) {
442 LOG(ERROR) << "Unknown type for placeholder " << node.name()
443 << ", skipping this input";
444 return nullptr;
445 }
446 DataType type = node.attr().at("dtype").type();
447
448 if (node.attr().count("shape") == 0) {
449 LOG(INFO) << "Unknown shape for placeholder " << node.name()
450 << ", skipping this input";
451 return nullptr;
452 }
453
454 // Replace all unknown dimensions in the placeholder's tensorshape proto
455 // with cfg.placeholder_unknown_output_shape_dim and create a tensorshape
456 // from it. We do this because in newer protos, the input placeholder
457 // shape is not empty if the shape is partially defined.
458 TensorShape shape;
459 TensorShapeProto shape_proto;
460 Status make_shape_status = ReplaceUnknownShapeDim(
461 cfg, node.attr().at("shape").shape(), &shape_proto, &shape);
462 if (!make_shape_status.ok()) {
463 LOG(ERROR) << "Invalid shape for placeholder " << node.name() << ": "
464 << make_shape_status << ", skipping this input";
465 return nullptr;
466 }
467
468 // Some placeholder nodes have a mis-match between the node
469 // attribute "shape" and a different node attribute "_output_shapes".
470 // Specifically, a shape with shape.dims() == 0 could indicate either
471 // a scalar or an unknown shape. In those cases, we check _output_shapes
472 // for additional information.
473 // This case is observed in the bnmt graphs. Have not observed any
474 // cases where there was more than 1 _output_shapes, so limit it
475 // to cases where there is only 1 _output_shapes.
476 // We only do this if cfg.placeholder_unknown_output_shape_dim has
477 // been set to avoid crashing non-BNMT graphs.
478 if ((cfg.placeholder_unknown_output_shape_dim >= 0) &&
479 (shape.dims() == 0) && (node.attr().count("_output_shapes") == 1)) {
480 const auto& output_shapes =
481 node.attr().at("_output_shapes").list().shape(0);
482
483 if (output_shapes.dim_size() != 0) {
484 shape.Clear();
485 shape_proto.clear_dim();
486
487 for (const auto& dim : output_shapes.dim()) {
488 auto size = dim.size();
489 if (size == -1) size = cfg.placeholder_unknown_output_shape_dim;
490 shape.AddDim(size);
491 shape_proto.add_dim()->set_size(size);
492 }
493 }
494 }
495
496 Tensor fake_input(type, shape);
497 InitializeTensor(type, &fake_input);
498
499 if (cfg.feed_nodes.empty()) {
500 // No specific feed nodes were given. Assume all placeholders are fed.
501 if (signature_feed_nodes.count(node.name()) == 0) {
502 new_item->feed.emplace_back(node.name(), fake_input);
503 }
504 } else if (cfg.feed_nodes.count(node.name()) > 0) {
505 // If specific feed nodes were given, only update their tensors.
506 auto it = find_if(new_item->feed.begin(), new_item->feed.end(),
507 [&node](std::pair<string, Tensor>& f) {
508 return f.first == node.name();
509 });
510 QCHECK(it != new_item->feed.end());
511 it->second = fake_input;
512 }
513
514 // Set the shape of the node in the graph. This is needed for statically
515 // inferring shapes and is a no-op when dynamically inferring shapes as
516 // the Placeholder shape will match the shape passed from new_item->feed.
517 *(node.mutable_attr()->at("shape").mutable_shape()) = shape_proto;
518 } else if (IsConstant(node)) {
519 auto it = asset_node_to_value.find(node.name());
520 if (it != asset_node_to_value.end()) {
521 auto iter = node.mutable_attr()->find("value");
522 if (iter == node.attr().end()) {
523 LOG(ERROR) << "Value attribute expected in const op for asset files";
524 return nullptr;
525 }
526 if (!iter->second.has_tensor() ||
527 iter->second.tensor().string_val_size() != 1) {
528 LOG(INFO) << "Unexpected AttrValue proto: "
529 << iter->second.DebugString();
530 return nullptr;
531 }
532 LOG(INFO) << "Using asset file " << it->second << " for node "
533 << node.name();
534 *(iter->second.mutable_tensor()->mutable_string_val(0)) = it->second;
535 }
536 }
537
538 // Erase the recorded result of any previous shape inference to start again
539 // from scratch.
540 node.mutable_attr()->erase("_output_shapes");
541
542 // Delete user specified placement if requested.
543 if (cfg.ignore_user_placement) {
544 node.clear_device();
545 }
546 // Delete colocation constraints if requested.
547 if (cfg.ignore_colocation) {
548 auto attr = node.mutable_attr();
549 auto it = attr->find("_class");
550 if (it != attr->end()) {
551 attr->erase(it);
552 }
553 }
554 }
555
556 if (meta_graph.collection_def().count("savers") > 0) {
557 const CollectionDef& savers = meta_graph.collection_def().at("savers");
558 for (const auto& raw : savers.bytes_list().value()) {
559 SaverDef saver;
560 // Skip bad savers since we don't need saves/restores to be able to run a
561 // graph.
562 if (!saver.ParseFromString(raw)) {
563 continue;
564 }
565 if (saver.filename_tensor_name().empty()) {
566 continue;
567 }
568 new_item->save_op = saver.save_tensor_name();
569 new_item->restore_op = saver.restore_op_name();
570 new_item->save_restore_loc_tensor = saver.filename_tensor_name();
571 // Only use the first saver since it's not clear what to do if there's
572 // more than one.
573 break;
574 }
575 } else {
576 const SaverDef& saver = meta_graph.saver_def();
577 new_item->save_op = saver.save_tensor_name();
578 new_item->restore_op = saver.restore_op_name();
579 new_item->save_restore_loc_tensor = saver.filename_tensor_name();
580 }
581
582 // Instantiate all the missing attributes with their default values.
583 Status attr_status = AddDefaultAttrsToGraphDef(
584 &new_item->graph,
585 FunctionLibraryDefinition(OpRegistry::Global(),
586 new_item->graph.library()),
587 0, true);
588 if (!attr_status.ok()) {
589 LOG(ERROR) << "Failed to instantiate default attribute values: "
590 << attr_status.error_message();
591 return nullptr;
592 }
593
594 // Optimize the graph (function inlining, l1 optimizations, etc).
595 VLOG(1) << "Number of nodes in graph before OptimizeGraph: "
596 << new_item->graph.node_size();
597 Status optimize_status =
598 OptimizeGraph(new_item->graph, &new_item->graph, cfg);
599 if (!optimize_status.ok()) {
600 LOG(ERROR) << "Graph preprocessing failed: " << optimize_status;
601 return nullptr;
602 }
603 VLOG(1) << "Number of nodes in graph after OptimizeGraph: "
604 << new_item->graph.node_size();
605
606 if (cfg.prune_graph) {
607 VLOG(1) << "Pruning graph...";
608 auto status = PruneGraph(new_item.get());
609 if (!status.ok()) {
610 LOG(ERROR) << "Pruning failed: " << status.error_message();
611 return nullptr;
612 }
613 VLOG(1) << "Number of nodes in graph after pruning: "
614 << new_item->graph.node_size();
615 }
616
617 // Validate feed, fetch and init nodes
618 std::unordered_set<string> nodes;
619 for (const auto& node : new_item->graph.node()) {
620 nodes.insert(node.name());
621 }
622 for (const auto& feed : new_item->feed) {
623 if (nodes.find(feed.first) == nodes.end()) {
624 LOG(ERROR) << "Feed node " << feed.first << " doesn't exist in graph";
625 return nullptr;
626 }
627 }
628 for (const auto& fetch : new_item->fetch) {
629 if (nodes.find(fetch) == nodes.end()) {
630 LOG(ERROR) << "Fetch node " << fetch << " doesn't exist in graph";
631 return nullptr;
632 }
633 }
634 for (const auto& init : new_item->init_ops) {
635 if (nodes.find(init) == nodes.end()) {
636 LOG(ERROR) << "Init node " << init << " doesn't exist in graph";
637 return nullptr;
638 }
639 }
640 return new_item;
641 }
642
GrapplerItemFromMetaGraphDefFile(const string & id,const string & meta_graph_file,const ItemConfig & cfg)643 std::unique_ptr<GrapplerItem> GrapplerItemFromMetaGraphDefFile(
644 const string& id, const string& meta_graph_file, const ItemConfig& cfg) {
645 MetaGraphDef meta_graph;
646 if (!ReadMetaGraphDefFromFile(meta_graph_file, &meta_graph).ok()) {
647 return nullptr;
648 }
649 return GrapplerItemFromMetaGraphDef(id, meta_graph, cfg);
650 }
651
652 } // end namespace grappler
653 } // end namespace tensorflow
654