1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""SavedModel builder implementation."""
16
17from __future__ import absolute_import
18from __future__ import division
19from __future__ import print_function
20
21import functools
22import os
23
24from google.protobuf.any_pb2 import Any
25
26from tensorflow.core.framework import types_pb2
27from tensorflow.core.protobuf import meta_graph_pb2
28from tensorflow.core.protobuf import saved_model_pb2
29from tensorflow.core.protobuf import saver_pb2
30from tensorflow.python.framework import dtypes
31from tensorflow.python.framework import ops
32from tensorflow.python.lib.io import file_io
33from tensorflow.python.ops import variables
34from tensorflow.python.platform import tf_logging
35from tensorflow.python.saved_model import constants
36from tensorflow.python.saved_model import signature_def_utils
37from tensorflow.python.saved_model import utils_impl as saved_model_utils
38from tensorflow.python.training import saver as tf_saver
39from tensorflow.python.util import compat
40from tensorflow.python.util.deprecation import deprecated_args
41from tensorflow.python.util.tf_export import tf_export
42
43
44# Base class for the SavedModelBuilder that is only used by Tensorflow
45# internally. Please use tf.compat.v1.saved_model.SavedModelBuilder instead.
46class _SavedModelBuilder(object):
47  """Builds the `SavedModel` protocol buffer and saves variables and assets.
48
49  The `SavedModelBuilder` class provides functionality to build a `SavedModel`
50  protocol buffer. Specifically, this allows multiple meta graphs to be saved as
51  part of a single language-neutral `SavedModel`, while sharing variables and
52  assets.
53
54  To build a SavedModel, the first meta graph must be saved with variables.
55  Subsequent meta graphs will simply be saved with their graph definitions. If
56  assets need to be saved and written or copied to disk, they can be provided
  when the meta graph def is added. If multiple meta graph defs are associated
  with an asset of the same name, only the first version is retained.
59
60  Each meta graph added to the SavedModel must be annotated with tags. The tags
61  provide a means to identify the specific meta graph to load and restore, along
62  with the shared set of variables and assets.
63
64  Typical usage for the `SavedModelBuilder`:
65  ```python
66  ...
67  builder = tf.saved_model.Builder(export_dir)
68
69  with tf.Session(graph=tf.Graph()) as sess:
70    ...
71    builder.add_meta_graph_and_variables(sess,
72                                    ["foo-tag"],
73                                    signature_def_map=foo_signatures,
74                                    assets_list=foo_assets)
75  ...
76
77  with tf.Session(graph=tf.Graph()) as sess:
78    ...
79    builder.add_meta_graph(["bar-tag", "baz-tag"])
80  ...
81
82  builder.save()
83  ```
84
  Note: This class will only be available through the v1 compatibility
  library as tf.compat.v1.saved_model.builder.SavedModelBuilder or
  tf.compat.v1.saved_model.Builder. Tensorflow 2.0 will introduce a new
  object-based method of creating SavedModels.
89  """
90
91  def __init__(self, export_dir):
92    self._saved_model = saved_model_pb2.SavedModel()
93    self._saved_model.saved_model_schema_version = (
94        constants.SAVED_MODEL_SCHEMA_VERSION)
95
96    self._export_dir = export_dir
97    if file_io.file_exists(export_dir):
98      if file_io.list_directory(export_dir):
99        raise AssertionError(
100            "Export directory already exists, and isn't empty. Please choose "
101            "a different export directory, or delete all the contents of the "
102            "specified directory: %s" % export_dir)
103    else:
104      file_io.recursive_create_dir(self._export_dir)
105
106    # Boolean to track whether variables and assets corresponding to the
107    # SavedModel have been saved. Specifically, the first meta graph to be added
108    # MUST use the add_meta_graph_and_variables() API. Subsequent add operations
109    # on the SavedModel MUST use the add_meta_graph() API which does not save
110    # weights.
111    self._has_saved_variables = False
112
113  def _save_and_write_assets(self, meta_graph_def, assets_list=None):
114    """Saves asset to the meta graph and writes asset files to disk.
115
116    Args:
117      meta_graph_def: The meta graph def to which the assets will be added.
118      assets_list: The list where the asset paths are setup.
119    """
120    # Creates a function that adds assets into the meta graph def.
121    write_fn = functools.partial(_add_asset_to_metagraph, meta_graph_def)
122    asset_filename_map = _maybe_save_assets(write_fn, assets_list)
123
124    # Return if there are no assets to write.
125    if not asset_filename_map:
126      tf_logging.info("No assets to write.")
127      return
128
129    # Copy assets from source path to destination path.
130    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)
131
132  def _tag_and_add_meta_graph(self, meta_graph_def, tags, signature_def_map):
133    """Tags the meta graph def and adds it to the SavedModel.
134
135    Tags the meta graph def with the supplied tags, adds signature defs to it if
136    provided and appends the meta graph def to the SavedModel proto.
137
138    Args:
139      meta_graph_def: The meta graph def to add to the SavedModel.
140      tags: The set of tags to annotate the meta graph def with.
141      signature_def_map: The map of signature defs to be added to the meta graph
142          def.
143    """
144    for tag in tags:
145      meta_graph_def.meta_info_def.tags.append(tag)
146
147    if signature_def_map is not None:
148      for key in signature_def_map:
149        meta_graph_def.signature_def[key].CopyFrom(signature_def_map[key])
150
151    proto_meta_graph_def = self._saved_model.meta_graphs.add()
152    proto_meta_graph_def.CopyFrom(meta_graph_def)
153
154  def _validate_tensor_info(self, tensor_info):
155    """Validates the `TensorInfo` proto.
156
157    Checks if the `encoding` (`name` or `coo_sparse`) and `dtype` fields exist
158    and are non-empty.
159
160    Args:
161      tensor_info: `TensorInfo` protocol buffer to validate.
162
163    Raises:
164      AssertionError: If the `name` or `dtype` fields of the supplied
165          `TensorInfo` proto are not populated.
166    """
167    if tensor_info is None:
168      raise AssertionError(
169          "All TensorInfo protos used in the SignatureDefs must have the name "
170          "and dtype fields set.")
171    if tensor_info.WhichOneof("encoding") is None:
172      # TODO(soergel) validate each of the fields of coo_sparse
173      raise AssertionError(
174          "All TensorInfo protos used in the SignatureDefs must have one of "
175          "the 'encoding' fields (e.g., name or coo_sparse) set: %s"
176          % tensor_info)
177    if tensor_info.dtype is types_pb2.DT_INVALID:
178      raise AssertionError(
179          "All TensorInfo protos used in the SignatureDefs must have the dtype "
180          "field set: %s" % tensor_info)
181
182  def _validate_signature_def_map(self, signature_def_map):
183    """Validates the `SignatureDef` entries in the signature def map.
184
185    Validation of entries in the signature def map includes ensuring that the
186    `name` and `dtype` fields of the TensorInfo protos of the `inputs` and
187    `outputs` of each `SignatureDef` are populated. Also ensures that reserved
188    SigantureDef keys for the initialization and train ops are not used.
189
190    Args:
191      signature_def_map: The map of signature defs to be validated.
192
193    Raises:
194      AssertionError: If a TensorInfo is not valid.
195      KeyError: If a reserved signature key is used in the map.
196    """
197    for signature_def_key in signature_def_map:
198      signature_def = signature_def_map[signature_def_key]
199      inputs = signature_def.inputs
200      outputs = signature_def.outputs
201      for inputs_key in inputs:
202        self._validate_tensor_info(inputs[inputs_key])
203      for outputs_key in outputs:
204        self._validate_tensor_info(outputs[outputs_key])
205    if constants.INIT_OP_SIGNATURE_KEY in signature_def_map:
206      raise KeyError(
207          "SignatureDef map key \"{}\" is reserved for initialization. Please "
208          "use a different key.".format(constants.INIT_OP_SIGNATURE_KEY))
209    if constants.TRAIN_OP_SIGNATURE_KEY in signature_def_map:
210      raise KeyError(
211          "SignatureDef map key \"{}\" is reserved for the train op. Please "
212          "use a different key.".format(constants.TRAIN_OP_SIGNATURE_KEY))
213
214  def _maybe_create_saver(self, saver=None):
215    """Creates a sharded saver if one does not already exist."""
216    if not saver:
217      # Initialize a saver to generate a sharded output for all saveables in the
218      # current scope.
219      saver = tf_saver.Saver(
220          variables._all_saveable_objects(),  # pylint: disable=protected-access
221          sharded=True,
222          write_version=saver_pb2.SaverDef.V2,
223          allow_empty=True)
224    return saver
225
226  def add_meta_graph(self,
227                     tags,
228                     signature_def_map=None,
229                     assets_list=None,
230                     clear_devices=False,
231                     init_op=None,
232                     train_op=None,
233                     saver=None):
234    """Adds the current meta graph to the SavedModel.
235
236    Creates a Saver in the current scope and uses the Saver to export the meta
237    graph def. Invoking this API requires the `add_meta_graph_and_variables()`
238    API to have been invoked before.
239
240    Args:
241      tags: The set of tags to annotate the meta graph def with.
242      signature_def_map: The map of signature defs to be added to the meta graph
243          def.
244      assets_list: Assets to be saved with SavedModel. Note
245          that this list should be a subset of the assets saved as part of
246          the first meta graph in the SavedModel.
247      clear_devices: Set to true if the device info on the default graph should
248          be cleared.
249      init_op: Op or group of ops to execute when the graph is loaded. Note
250          that when the init_op is specified it is run after the restore op at
251          load-time.
252      train_op: Op or group of opts that trains the model when run. This will
253        not be run automatically when the graph is loaded, instead saved in
254        a SignatureDef accessible through the exported MetaGraph.
255      saver: An instance of tf.train.Saver that will be used to export the
256        metagraph. If None, a sharded Saver that restores all variables will
257        be used.
258
259    Raises:
260      AssertionError: If the variables for the SavedModel have not been saved
261          yet, or if the graph already contains one or more legacy init ops.
262    """
263    if not self._has_saved_variables:
264      raise AssertionError(
265          "Graph state including variables and assets has not been saved yet. "
266          "Please invoke `add_meta_graph_and_variables()` first.")
267
268    # Validate the signature def map to ensure all included TensorInfos are
269    # properly populated.
270    signature_def_map = signature_def_map or {}
271    self._validate_signature_def_map(signature_def_map)
272
273    # Create a SignatureDef pointing to the graph initialization op, which will
274    # be added to the MetaGraphDef.
275    _add_op_to_signature_def_map(signature_def_map, init_op,
276                                 constants.INIT_OP_SIGNATURE_KEY)
277    _add_op_to_signature_def_map(signature_def_map, train_op,
278                                 constants.TRAIN_OP_SIGNATURE_KEY)
279
280    saver = self._maybe_create_saver(saver)
281
282    # The graph almost certainly previously contained at least one Saver, and
283    # possibly several (e.g. one for loading a pretrained embedding, and another
284    # for the model weights).  Removing the preexisting ones was the
285    # motivation for the clear_extraneous_savers option, but it turns out that
286    # there are edge cases where that option breaks the graph.  Until that is
287    # resolved, we just leave the option set to False for now.
288    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
289    meta_graph_def = saver.export_meta_graph(
290        clear_devices=clear_devices, strip_default_attrs=True)
291
292    # Save asset files and write them to disk, if any.
293    self._save_and_write_assets(meta_graph_def, assets_list)
294
295    # Tag the meta graph def and add it to the SavedModel.
296    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
297
298  def add_meta_graph_and_variables(self,
299                                   sess,
300                                   tags,
301                                   signature_def_map=None,
302                                   assets_list=None,
303                                   clear_devices=False,
304                                   init_op=None,
305                                   train_op=None,
306                                   strip_default_attrs=False,
307                                   saver=None):
308    # pylint: disable=line-too-long
309    """Adds the current meta graph to the SavedModel and saves variables.
310
311    Creates a Saver to save the variables from the provided session. Exports the
312    corresponding meta graph def. This function assumes that the variables to be
313    saved have been initialized. For a given `SavedModelBuilder`, this API must
314    be called exactly once and for the first meta graph to save. For subsequent
315    meta graph defs to be added, the `add_meta_graph()` API must be used.
316
317    Args:
318      sess: The TensorFlow session from which to save the meta graph and
319        variables.
320      tags: The set of tags with which to save the meta graph.
321      signature_def_map: The map of signature def map to add to the meta graph
322        def.
323      assets_list: Assets to be saved with SavedModel.
324      clear_devices: Set to true if the device info on the default graph should
325          be cleared.
326      init_op: Op or group of ops to execute when the graph is loaded. Note
327          that when the init_op is specified it is run after the restore op at
328          load-time.
329      train_op: Op or group of ops that trains the model when run. This will
330        not be run automatically when the graph is loaded, instead saved in
331        a SignatureDef accessible through the exported MetaGraph.
332      strip_default_attrs: Boolean. If `True`, default-valued attributes will be
333        removed from the NodeDefs. For a detailed guide, see
334        [Stripping Default-Valued Attributes](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/saved_model/README.md#stripping-default-valued-attributes).
335      saver: An instance of tf.train.Saver that will be used to export the
336        metagraph and save variables. If None, a sharded Saver that restores
337        all variables will be used.
338
339    """
340    # pylint: enable=line-too-long
341    if self._has_saved_variables:
342      raise AssertionError("Graph state including variables and assets has "
343                           "already been saved. Please invoke "
344                           "`add_meta_graph()` instead.")
345
346    # Validate the signature def map to ensure all included TensorInfos are
347    # properly populated.
348    signature_def_map = signature_def_map or {}
349    self._validate_signature_def_map(signature_def_map)
350
351    # Create a SignatureDef pointing to the graph initialization op, which will
352    # be added to the MetaGraphDef.
353    _add_op_to_signature_def_map(signature_def_map, init_op,
354                                 constants.INIT_OP_SIGNATURE_KEY)
355    _add_op_to_signature_def_map(signature_def_map, train_op,
356                                 constants.TRAIN_OP_SIGNATURE_KEY)
357
358    saved_model_utils.get_or_create_variables_dir(self._export_dir)
359    variables_path = saved_model_utils.get_variables_path(self._export_dir)
360
361    saver = self._maybe_create_saver(saver)
362
363    # Save the variables. Also, disable writing the checkpoint state proto. The
364    # file is not used during SavedModel loading. In addition, since a
365    # SavedModel can be copied or moved, this avoids the checkpoint state to
366    # become outdated.
367    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)
368
369    # Export the meta graph def.
370
371    # The graph almost certainly previously contained at least one Saver, and
372    # possibly several (e.g. one for loading a pretrained embedding, and another
373    # for the model weights).  Removing the preexisting ones was the
374    # motivation for the clear_extraneous_savers option, but it turns out that
375    # there are edge cases where that option breaks the graph.  Until that is
376    # resolved, we just leave the option set to False for now.
377    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
378    meta_graph_def = saver.export_meta_graph(
379        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
380
381    # Save asset files and write them to disk, if any.
382    self._save_and_write_assets(meta_graph_def, assets_list)
383
384    # Tag the meta graph def and add it to the SavedModel.
385    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
386
387    # Mark this instance of SavedModel as having saved variables, such that
388    # subsequent attempts to save variables will fail.
389    self._has_saved_variables = True
390
391  def save(self, as_text=False):
392    """Writes a `SavedModel` protocol buffer to disk.
393
394    The function writes the SavedModel protocol buffer to the export directory
395    in serialized format.
396
397    Args:
398      as_text: Writes the SavedModel protocol buffer in text format to disk.
399
400    Returns:
401      The path to which the SavedModel protocol buffer was written.
402    """
403    if not file_io.file_exists(self._export_dir):
404      file_io.recursive_create_dir(self._export_dir)
405
406    if as_text:
407      path = os.path.join(
408          compat.as_bytes(self._export_dir),
409          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PBTXT))
410      file_io.write_string_to_file(path, str(self._saved_model))
411    else:
412      path = os.path.join(
413          compat.as_bytes(self._export_dir),
414          compat.as_bytes(constants.SAVED_MODEL_FILENAME_PB))
415      file_io.write_string_to_file(path, self._saved_model.SerializeToString())
416    tf_logging.info("SavedModel written to: %s", compat.as_text(path))
417
418    return path
419
420
421@tf_export(v1=["saved_model.Builder", "saved_model.builder.SavedModelBuilder"])  # pylint: disable=missing-docstring
422class SavedModelBuilder(_SavedModelBuilder):
423  __doc__ = _SavedModelBuilder.__doc__.replace("assets_list",
424                                               "assets_collection")
425
  def __init__(self, export_dir):
    """Initializes the builder by delegating to the base class constructor.

    Args:
      export_dir: Forwarded unchanged as the base class's `export_dir`.
    """
    super(SavedModelBuilder, self).__init__(export_dir=export_dir)
428
429  def _add_collections(self, assets_collection, main_op, train_op):
430    """Add asset and op collections to be saved."""
431    # Save asset files and write them to disk, if any.
432    self._save_and_write_assets(assets_collection)
433
434    self._maybe_add_main_op(main_op)
435
436    self._add_train_op(train_op)
437
438  def _save_and_write_assets(self, assets_collection_to_add=None):
439    """Saves asset to the meta graph and writes asset files to disk.
440
441    Args:
442      assets_collection_to_add: The collection where the asset paths are setup.
443    """
444    # Add assets to the collection with key `constants.ASSETS_KEY`, in the
445    # graph.
446    asset_filename_map = _maybe_save_assets(_add_asset_to_collection,
447                                            assets_collection_to_add)
448
449    # Return if there are no assets to write.
450    if not asset_filename_map:
451      tf_logging.info("No assets to write.")
452      return
453
454    # Copy assets from source path to destination path.
455    copy_assets_to_destination_dir(asset_filename_map, self._export_dir)
456
457  def _maybe_add_main_op(self, main_op):
458    """Adds main op to the SavedModel.
459
460    Args:
461      main_op: Main op to run as part of graph initialization. If None, no main
462        op will be added to the graph.
463
464    Raises:
465      TypeError: if main op is provided but is not of type `Operation`.
466      ValueError: if the Graph already contains an init op.
467    """
468    if main_op is None:
469      return
470
471    if not isinstance(main_op, ops.Operation):
472      raise TypeError("main_op needs to be an Operation: %r" % main_op)
473
474    # Validate that no other init ops have been added to this graph already.
475    # We check main_op and legacy_init_op for thoroughness and explicitness.
476    for init_op_key in (constants.MAIN_OP_KEY, constants.LEGACY_INIT_OP_KEY):
477      if ops.get_collection(init_op_key):
478        raise ValueError(
479            "Graph already contains one or more main ops under the "
480            "collection {}.".format(init_op_key))
481
482    ops.add_to_collection(constants.MAIN_OP_KEY, main_op)
483
484  def _add_train_op(self, train_op):
485    """Add train op to the SavedModel.
486
487    Note that this functionality is in development, and liable to be
488    moved elsewhere.
489
490    Args:
491      train_op: Op or group of ops that are used for training. These are stored
492        as a collection with key TRAIN_OP_KEY, but not executed.
493
494    Raises:
495      TypeError if Train op is not of type `Operation`.
496    """
497    if train_op is not None:
498      if (not isinstance(train_op, ops.Tensor) and
499          not isinstance(train_op, ops.Operation)):
500        raise TypeError("train_op needs to be a Tensor or Op: %r" % train_op)
501      ops.add_to_collection(constants.TRAIN_OP_KEY, train_op)
502
503  @deprecated_args(None,
504                   "Pass your op to the equivalent parameter main_op instead.",
505                   "legacy_init_op")
506  def add_meta_graph(self,
507                     tags,
508                     signature_def_map=None,
509                     assets_collection=None,
510                     legacy_init_op=None,
511                     clear_devices=False,
512                     main_op=None,
513                     strip_default_attrs=False,
514                     saver=None):
515    if not self._has_saved_variables:
516      raise AssertionError(
517          "Graph state including variables and assets has not been saved yet. "
518          "Please invoke `add_meta_graph_and_variables()` first.")
519
520    # Validate the signature def map to ensure all included TensorInfos are
521    # properly populated.
522    signature_def_map = signature_def_map or {}
523    self._validate_signature_def_map(signature_def_map)
524
525    # legacy_init_op is deprecated, and going away in TF 2.0.
526    # Re-mapping to main_op, as treatment is identical regardless.
527    main_op = main_op or legacy_init_op
528
529    # Add assets and ops
530    self._add_collections(assets_collection, main_op, None)
531
532    saver = self._maybe_create_saver(saver)
533
534    # The graph almost certainly previously contained at least one Saver, and
535    # possibly several (e.g. one for loading a pretrained embedding, and another
536    # for the model weights).  Removing the preexisting ones was the
537    # motivation for the clear_extraneous_savers option, but it turns out that
538    # there are edge cases where that option breaks the graph.  Until that is
539    # resolved, we just leave the option set to False for now.
540    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
541    meta_graph_def = saver.export_meta_graph(
542        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
543
544    # Tag the meta graph def and add it to the SavedModel.
545    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
546
547  @deprecated_args(None,
548                   "Pass your op to the equivalent parameter main_op instead.",
549                   "legacy_init_op")
550  def add_meta_graph_and_variables(self,
551                                   sess,
552                                   tags,
553                                   signature_def_map=None,
554                                   assets_collection=None,
555                                   legacy_init_op=None,
556                                   clear_devices=False,
557                                   main_op=None,
558                                   strip_default_attrs=False,
559                                   saver=None):
560    if self._has_saved_variables:
561      raise AssertionError("Graph state including variables and assets has "
562                           "already been saved. Please invoke "
563                           "`add_meta_graph()` instead.")
564
565    # Validate the signature def map to ensure all included TensorInfos are
566    # properly populated.
567    signature_def_map = signature_def_map or {}
568    self._validate_signature_def_map(signature_def_map)
569
570    # legacy_init_op is deprecated, and going away in TF 2.0.
571    # Re-mapping to main_op, as treatment is identical regardless.
572    main_op = main_op or legacy_init_op
573
574    # Add assets and ops
575    self._add_collections(assets_collection, main_op, None)
576
577    saved_model_utils.get_or_create_variables_dir(self._export_dir)
578    variables_path = saved_model_utils.get_variables_path(self._export_dir)
579
580    saver = self._maybe_create_saver(saver)
581
582    # Save the variables. Also, disable writing the checkpoint state proto. The
583    # file is not used during SavedModel loading. In addition, since a
584    # SavedModel can be copied or moved, this avoids the checkpoint state to
585    # become outdated.
586    saver.save(sess, variables_path, write_meta_graph=False, write_state=False)
587
588    # Export the meta graph def.
589
590    # The graph almost certainly previously contained at least one Saver, and
591    # possibly several (e.g. one for loading a pretrained embedding, and another
592    # for the model weights).  Removing the preexisting ones was the
593    # motivation for the clear_extraneous_savers option, but it turns out that
594    # there are edge cases where that option breaks the graph.  Until that is
595    # resolved, we just leave the option set to False for now.
596    # TODO(soergel): Reinstate clear_extraneous_savers=True when possible.
597    meta_graph_def = saver.export_meta_graph(
598        clear_devices=clear_devices, strip_default_attrs=strip_default_attrs)
599
600    # Tag the meta graph def and add it to the SavedModel.
601    self._tag_and_add_meta_graph(meta_graph_def, tags, signature_def_map)
602
603    # Mark this instance of SavedModel as having saved variables, such that
604    # subsequent attempts to save variables will fail.
605    self._has_saved_variables = True
606
  # Reuse the base-class docstrings for the two overriding methods above,
  # renaming the `assets_list` argument to this class's `assets_collection`.
  add_meta_graph.__doc__ = _SavedModelBuilder.add_meta_graph.__doc__.replace(
      "assets_list", "assets_collection")
  add_meta_graph_and_variables.__doc__ = \
      _SavedModelBuilder.add_meta_graph_and_variables.__doc__.replace(
          "assets_list", "assets_collection")
612
613
def _maybe_save_assets(write_fn, assets_to_add=None):
  """Registers every supplied asset through `write_fn`.

  Args:
    write_fn: Callback invoked as `write_fn(asset_filename, asset_tensor)` to
      record an asset in the meta graph.
    assets_to_add: The list where the asset paths are setup, or None.

  Returns:
    A dict mapping the filename each asset is saved under to the original full
    path of that asset. Empty when `assets_to_add` is None.

  Raises:
    ValueError: Indicating an invalid filepath tensor.
  """
  # Target file name -> original file path.
  filename_to_source = {}

  if assets_to_add is not None:
    for asset_tensor in assets_to_add:
      source_path = _asset_path_from_tensor(asset_tensor)
      if not source_path:
        raise ValueError("Invalid asset filepath tensor %s" % asset_tensor)

      target_name = get_asset_filename_to_add(source_path, filename_to_source)

      # The callback builds the AssetFileDef proto and stores it in either the
      # collection or the asset_file_def field of the meta graph. It runs even
      # for a duplicate of an already-added file, because the tensor reference
      # should still exist.
      write_fn(target_name, asset_tensor)

      # For duplicates the last filepath wins as the copy source. The files
      # are the same in that case, so it doesn't matter either way.
      filename_to_source[target_name] = source_path

    tf_logging.info("Assets added to graph.")
  else:
    tf_logging.info("No assets to save.")

  return filename_to_source
658
659
def get_asset_filename_to_add(asset_filepath, asset_filename_map):
  """Picks the basename under which an asset is stored in the SavedModel.

  Users supply assets as full paths, while the SavedModel stores them by
  basename, so basenames of different files can collide. The basename is
  returned unchanged when it is unused, when it is already mapped to this
  exact path, or when the file it is mapped to has identical contents.
  Otherwise an index-suffixed basename is generated.

  Args:
    asset_filepath: the full path to the asset that is being saved
    asset_filename_map: a dict of filenames used for saving the asset in
      the SavedModel to full paths from which the filenames were derived.

  Returns:
    Uniquified filename string if the file is not a duplicate, or the original
    filename if the file has already been seen and saved.
  """
  basename = os.path.basename(asset_filepath)

  # Unseen basename: safe to use as is.
  if basename not in asset_filename_map:
    return basename

  known_filepath = asset_filename_map[basename]
  # Either the very same path listed twice, or a different path whose file
  # contents are identical: in both cases the existing name is reused.
  if (known_filepath == asset_filepath or
      file_io.filecmp(asset_filepath, known_filepath)):
    return basename

  # A different file sharing the basename: dedupe the filename.
  return _get_unique_asset_filename(basename, asset_filename_map)
697
698
def _get_unique_asset_filename(asset_filename, asset_filename_map):
  """Returns a filename that is not yet a key of `asset_filename_map`.

  Args:
    asset_filename: The preferred filename.
    asset_filename_map: Dict whose keys are the filenames already in use.

  Returns:
    `asset_filename` itself when it is unused; otherwise the bytes
    `<asset_filename>_<i>` for the first index `i`, counting from 1, that is
    not a key of the map.
  """
  candidate = asset_filename
  suffix = 0
  while candidate in asset_filename_map:
    suffix += 1
    name_parts = (compat.as_bytes(asset_filename),
                  compat.as_bytes(str(suffix)))
    candidate = compat.as_bytes("_").join(name_parts)
  return candidate
707
708
def _asset_path_from_tensor(path_tensor):
  """Extracts the file path held by the constant `path_tensor`.

  Args:
    path_tensor: Tensor of a file-path.

  Returns:
    The single string value stored in the tensor's `value` attribute.

  Raises:
    TypeError if tensor does not match expected op type, dtype or value.
  """
  if not isinstance(path_tensor, ops.Tensor):
    raise TypeError("Asset path tensor must be a Tensor.")

  producer = path_tensor.op
  if producer.type != "Const":
    raise TypeError("Asset path tensor must be of type constant.")
  if path_tensor.dtype != dtypes.string:
    raise TypeError("Asset path tensor must be of dtype string.")

  # Exactly one string value is accepted.
  string_values = producer.get_attr("value").string_val
  if len(string_values) != 1:
    raise TypeError("Asset path tensor must be a scalar.")
  return string_values[0]
731
732
def _add_asset_to_metagraph(meta_graph_def, asset_filename, asset_tensor):
  """Appends a new asset entry to `meta_graph_def.asset_file_def`.

  Args:
    meta_graph_def: The meta graph def to which the asset will be added.
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor; its `name` is stored as the tensor info
      name of the new entry.
  """
  new_asset = meta_graph_def.asset_file_def.add()
  new_asset.filename = asset_filename
  new_asset.tensor_info.name = asset_tensor.name
745
746
def copy_assets_to_destination_dir(asset_filename_map, destination_dir):
  """Copies every asset into the assets directory of `destination_dir`.

  Args:
    asset_filename_map: Dict mapping the target basename of each asset to the
      source path it is copied from.
    destination_dir: Directory whose assets directory (obtained through
      `saved_model_utils.get_or_create_assets_dir`) receives the copies.
  """
  assets_destination_dir = saved_model_utils.get_or_create_assets_dir(
      destination_dir)
  destination_root = compat.as_bytes(assets_destination_dir)

  for target_basename, source_filepath in asset_filename_map.items():
    target_filepath = os.path.join(
        destination_root, compat.as_bytes(target_basename))

    # An asset that is already present is left untouched, so an asset with
    # the same name defined as part of multiple graphs is only copied the
    # first time.
    if file_io.file_exists(target_filepath):
      continue
    file_io.copy(source_filepath, target_filepath)

  tf_logging.info("Assets written to: %s",
                  compat.as_text(assets_destination_dir))
766
767
def _add_asset_to_collection(asset_filename, asset_tensor):
  """Packs an asset proto and appends it to the graph's asset collection.

  Args:
    asset_filename: The filename of the asset to be added.
    asset_tensor: The asset tensor; its `name` is stored as the tensor info
        name of the asset proto.
  """
  asset_file_def = meta_graph_pb2.AssetFileDef()
  asset_file_def.filename = asset_filename
  asset_file_def.tensor_info.name = asset_tensor.name

  # The proto is wrapped in an `Any` before being added to the collection
  # keyed by `constants.ASSETS_KEY`.
  packed_asset = Any()
  packed_asset.Pack(asset_file_def)
  ops.add_to_collection(constants.ASSETS_KEY, packed_asset)
783
784
def _add_op_to_signature_def_map(signature_def_map, op, key):
  """Stores a SignatureDef for `op` under `key`, unless `op` is None.

  Args:
    signature_def_map: The map that is updated in place.
    op: The op to build the SignatureDef from; nothing is stored when None.
    key: The map key, also passed on to `signature_def_utils.op_signature_def`.
  """
  if op is None:
    return
  signature_def_map[key] = signature_def_utils.op_signature_def(op, key)
788