1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python utilities required by Keras."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import binascii
21import codecs
22import importlib
23import marshal
24import os
25import re
26import sys
27import threading
28import time
29import types as python_types
30import weakref
31
32import numpy as np
33import six
34from tensorflow.python.keras.utils import tf_contextlib
35from tensorflow.python.keras.utils import tf_inspect
36from tensorflow.python.util import nest
37from tensorflow.python.util import tf_decorator
38from tensorflow.python.util.tf_export import keras_export
39
# Module-wide registry mapping a serialized name to the custom object stored
# under that name. It is mutated by `CustomObjectScope` and
# `register_keras_serializable`, and handed out by `get_custom_objects()`.
_GLOBAL_CUSTOM_OBJECTS = {}
# Reverse registry mapping a registered object to its serialized name. It is
# written by `register_keras_serializable` and read by `get_registered_name`.
_GLOBAL_CUSTOM_NAMES = {}

# Flag that determines whether to skip the NotImplementedError when calling
# get_config in custom models and layers. This is only enabled when saving to
# SavedModel, when the config isn't required.
_SKIP_FAILED_SERIALIZATION = False
# If a layer does not have a defined config, then the returned config will be a
# dictionary with the below key.
_LAYER_UNDEFINED_CONFIG_KEY = 'layer was saved without config'
50
51
@keras_export('keras.utils.custom_object_scope',  # pylint: disable=g-classes-have-attributes
              'keras.utils.CustomObjectScope')
class CustomObjectScope(object):
  """Temporarily registers custom classes/functions for Keras deserialization.

  While a `with custom_object_scope(objects_dict)` block is active, the given
  `{name: object}` pairs are merged into the global custom object registry, so
  Keras methods such as `tf.keras.models.load_model` or
  `tf.keras.models.model_from_config` can resolve custom objects referenced by
  a saved config (e.g. a custom layer or metric). On exit the registry is
  reset to the snapshot taken on entry.

  Example:

  Consider a custom regularizer `my_regularizer`:

  ```python
  layer = Dense(3, kernel_regularizer=my_regularizer)
  config = layer.get_config()  # Config contains a reference to `my_regularizer`
  ...
  # Later:
  with custom_object_scope({'my_regularizer': my_regularizer}):
    layer = Dense.from_config(config)
  ```

  Args:
      *args: Dictionary or dictionaries of `{name: object}` pairs.
  """

  def __init__(self, *args):
    # The registry snapshot is only taken once the scope is entered.
    self.backup = None
    self.custom_objects = args

  def __enter__(self):
    # Snapshot first, so `__exit__` can roll back everything added below.
    self.backup = dict(_GLOBAL_CUSTOM_OBJECTS)
    for name_to_object in self.custom_objects:
      _GLOBAL_CUSTOM_OBJECTS.update(name_to_object)
    return self

  def __exit__(self, *args, **kwargs):
    # Mutate the shared dict in place (rather than rebinding it) so that live
    # references handed out by `get_custom_objects()` stay valid.
    snapshot = self.backup
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(snapshot)
92
93
@keras_export('keras.utils.get_custom_objects')
def get_custom_objects():
  """Returns the live global dictionary of custom objects.

  The returned dictionary is the registry itself, not a copy: mutating it
  changes what Keras can resolve during deserialization. Using
  `custom_object_scope` to update and clear custom objects is preferred, but
  this accessor allows direct manipulation when needed.

  Example:

  ```python
  get_custom_objects().clear()
  get_custom_objects()['MyObject'] = MyObject
  ```

  Returns:
      Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
  """
  registry = _GLOBAL_CUSTOM_OBJECTS
  return registry
113
114
# Store a unique, per-object ID for shared objects.
#
# We store a unique ID for each object so that we may, at loading time,
# re-create the network properly.  Without this ID, we would have no way of
# determining whether a config is a description of a new object that
# should be created or is merely a reference to an already-created object.
SHARED_OBJECT_KEY = 'shared_object_id'


# Thread-local holders for shared-object state. `SHARED_OBJECT_DISABLED`
# carries a `disabled` attribute, while `SHARED_OBJECT_LOADING` and
# `SHARED_OBJECT_SAVING` each carry a `scope` attribute. The attributes are
# not set here; read them through the `_shared_object_*` accessor functions,
# which supply a default when the current thread has not set them.
SHARED_OBJECT_DISABLED = threading.local()
SHARED_OBJECT_LOADING = threading.local()
SHARED_OBJECT_SAVING = threading.local()
127
128
129# Attributes on the threadlocal variable must be set per-thread, thus we
130# cannot initialize these globally. Instead, we have accessor functions with
131# default values.
def _shared_object_disabled():
  """Reports whether shared object handling is disabled on this thread.

  Returns:
    The `disabled` attribute of the thread-local `SHARED_OBJECT_DISABLED`, or
    `False` if the current thread never set it.
  """
  is_disabled = getattr(SHARED_OBJECT_DISABLED, 'disabled', False)
  return is_disabled
135
136
def _shared_object_loading_scope():
  """Get the current shared object loading scope in a threadsafe manner.

  Returns:
    The `scope` attribute of the thread-local `SHARED_OBJECT_LOADING`, or a
    new `NoopLoadingScope` if the current thread has not set one.
  """
  try:
    return SHARED_OBJECT_LOADING.scope
  except AttributeError:
    # No scope has been set on this thread yet. Only build the no-op
    # fallback in that case instead of allocating one on every call.
    return NoopLoadingScope()
140
141
def _shared_object_saving_scope():
  """Fetches this thread's active shared object saving scope, if any.

  Returns:
    The `scope` attribute of the thread-local `SHARED_OBJECT_SAVING`, or
    `None` if the current thread has not set one.
  """
  active_scope = getattr(SHARED_OBJECT_SAVING, 'scope', None)
  return active_scope
145
146
class DisableSharedObjectScope(object):
  """A context manager for disabling handling of shared objects.

  Disables shared object handling for both saving and loading.

  Created primarily for use with `clone_model`, which does extra surgery that
  is incompatible with shared objects.

  On exit, the `disabled` flag and the loading/saving scopes that were active
  on entry are restored, so these scopes can be nested safely: leaving an
  inner scope does not re-enable shared object handling while an outer
  `DisableSharedObjectScope` is still active.
  """

  def __enter__(self):
    # Remember the flag as it was on entry (False unless we are nested) so
    # that `__exit__` can restore it rather than blindly re-enabling.
    self._orig_disabled = _shared_object_disabled()
    SHARED_OBJECT_DISABLED.disabled = True
    self._orig_loading_scope = _shared_object_loading_scope()
    self._orig_saving_scope = _shared_object_saving_scope()

  def __exit__(self, *args, **kwargs):
    SHARED_OBJECT_DISABLED.disabled = self._orig_disabled
    # Put back whatever scopes were active on entry; scopes opened and closed
    # inside this block may have overwritten them.
    SHARED_OBJECT_LOADING.scope = self._orig_loading_scope
    SHARED_OBJECT_SAVING.scope = self._orig_saving_scope
165
166
class NoopLoadingScope(object):
  """Fallback shared object loading scope that never remembers anything.

  It exposes the same `get`/`set` methods as `SharedObjectLoadingScope`, so
  serialization code that doesn't care about shared objects (e.g. when
  serializing a single object) needs no special casing.
  """

  def get(self, unused_object_id):
    """Always reports that no object is known for the given ID."""
    return None

  def set(self, object_id, obj):
    """Ignores the object; nothing is recorded."""
179
180
class SharedObjectLoadingScope(object):
  """A context manager for keeping track of loaded objects.

  During the deserialization process, we may come across objects that are
  shared across multiple layers. In order to accurately restore the network
  structure to its original state, `SharedObjectLoadingScope` allows us to
  re-use shared objects rather than cloning them.
  """

  def __init__(self):
    # Create the lookup table up front so that `get`/`set` are safe to call
    # even if the scope was never entered, or was entered while shared object
    # handling is disabled (in which case `__enter__` returns early).
    self._obj_ids_to_obj = {}

  def __enter__(self):
    """Installs this scope as the current thread's loading scope.

    Returns:
      This scope, or a `NoopLoadingScope` when shared object handling is
      disabled (in which case the thread's current scope is left untouched).
    """
    if _shared_object_disabled():
      return NoopLoadingScope()

    SHARED_OBJECT_LOADING.scope = self
    # Start every entry with an empty table, even when the instance is reused.
    self._obj_ids_to_obj = {}
    return self

  def get(self, object_id):
    """Given a shared object ID, returns a previously instantiated object.

    Args:
      object_id: shared object ID to use when attempting to find already-loaded
        object.

    Returns:
      The object, if we've seen this ID before. Else, `None`.
    """
    # Explicitly check for `None` internally to make external calling code a
    # bit cleaner.
    if object_id is None:
      return None
    return self._obj_ids_to_obj.get(object_id)

  def set(self, object_id, obj):
    """Stores an instantiated object for future lookup and sharing.

    Args:
      object_id: shared object ID to file `obj` under. If `None`, nothing is
        stored.
      obj: the instantiated object.
    """
    if object_id is None:
      return
    self._obj_ids_to_obj[object_id] = obj

  def __exit__(self, *args, **kwargs):
    # Reset the thread's loading scope to the do-nothing default.
    SHARED_OBJECT_LOADING.scope = NoopLoadingScope()
224
225
class SharedObjectConfig(dict):
  """A config dictionary that counts how often it has been referenced.

  Once a `SharedObjectConfig` is referenced more than once, it attaches its
  shared object ID to itself, allowing for proper shared object
  reconstruction at load time.

  In most cases, it would be more proper to subclass something like
  `collections.UserDict` or `collections.Mapping` rather than `dict` directly.
  Unfortunately, python's json encoder does not support `Mapping`s. This is
  important functionality to retain, since we are dealing with serialization.

  Subclassing `dict` should be safe here, since no core methods are
  overridden; the class only adds one method for reference counting.
  """

  def __init__(self, base_config, object_id, **kwargs):
    super(SharedObjectConfig, self).__init__(base_config, **kwargs)
    self.object_id = object_id
    # The config counts as referenced once as soon as it is created.
    self.ref_count = 1

  def increment_ref_count(self):
    """Records one more reference, attaching the shared ID on the second."""
    # The shared object ID is attached only once the object has been seen
    # more than once. Attaching it only when strictly necessary makes
    # backwards compatibility breakage less likely.
    seen_only_once = self.ref_count == 1
    self.ref_count += 1
    if seen_only_once:
      self[SHARED_OBJECT_KEY] = self.object_id
256
257
class SharedObjectSavingScope(object):
  """Context manager that tracks shared object configs during serialization."""

  def __enter__(self):
    """Activates this scope unless disabled or already nested in another.

    Returns:
      `None` when shared object handling is disabled, the already-active
      outer scope when there is one, and this scope otherwise.
    """
    if _shared_object_disabled():
      return None

    # Serialization can happen at a number of layers for a number of reasons.
    # We may end up with a case where we're opening a saving scope within
    # another saving scope. In that case, we'd like to use the outermost scope
    # available and ignore inner scopes, since there is not (yet) a reasonable
    # use case for having these nested and distinct.
    outer_scope = _shared_object_saving_scope()
    self._passthrough = outer_scope is not None
    if self._passthrough:
      return outer_scope

    SHARED_OBJECT_SAVING.scope = self
    # Weak keys: tracking an object here does not keep it alive.
    self._shared_objects_config = weakref.WeakKeyDictionary()
    self._next_id = 0
    return self

  def get_config(self, obj):
    """Gets a `SharedObjectConfig` if one has already been seen for `obj`.

    A successful lookup counts as one more reference to the config.

    Args:
      obj: The object for which to retrieve the `SharedObjectConfig`.

    Returns:
      The SharedObjectConfig for a given object, if already seen. Else,
        `None`.
    """
    try:
      known_config = self._shared_objects_config[obj]
    except (TypeError, KeyError):
      # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
      # that has not overridden `__hash__`), a `TypeError` will be thrown.
      # We'll just continue on without shared object support.
      return None
    else:
      known_config.increment_ref_count()
      return known_config

  def create_config(self, base_config, obj):
    """Create a new SharedObjectConfig for a given object.

    Args:
      base_config: The config contents to wrap.
      obj: The object the config describes; used as the lookup key for later
        `get_config` calls.

    Returns:
      The newly created `SharedObjectConfig`, carrying the next unused ID.
    """
    new_config = SharedObjectConfig(base_config, self._next_id)
    self._next_id += 1
    try:
      self._shared_objects_config[obj] = new_config
    except TypeError:
      # If the object is unhashable (e.g. a subclass of `AbstractBaseClass`
      # that has not overridden `__hash__`), a `TypeError` will be thrown.
      # We'll just continue on without shared object support.
      pass
    return new_config

  def __exit__(self, *args, **kwargs):
    # A passthrough scope never installed itself, so the outer scope must be
    # left in place.
    if getattr(self, '_passthrough', False):
      return
    SHARED_OBJECT_SAVING.scope = None
320
321
def serialize_keras_class_and_config(
    cls_name, cls_config, obj=None, shared_object_id=None):
  """Returns the serialization of the class with the given config.

  Args:
    cls_name: Name stored under the `'class_name'` key.
    cls_config: Config stored under the `'config'` key.
    obj: Optional object being serialized. When given and a
      `SharedObjectSavingScope` is active, the scope is used to find or
      create the config for this object.
    shared_object_id: Optional shared object ID to store in the new config.

  Returns:
    A plain dict when no saving scope is active or `obj` is `None`;
    otherwise the scope's `SharedObjectConfig` for `obj`.
  """
  base_config = {'class_name': cls_name, 'config': cls_config}

  # We call `serialize_keras_class_and_config` for some branches of the load
  # path. In that case, we may already have a shared object ID we'd like to
  # retain.
  if shared_object_id is not None:
    base_config[SHARED_OBJECT_KEY] = shared_object_id

  saving_scope = _shared_object_saving_scope()
  if saving_scope is None or obj is None:
    return base_config

  # With an active `SharedObjectSavingScope`, reuse the config if this object
  # was already serialized. Reusing it stores an extra ID field in the config,
  # allowing us to re-create the shared object relationship at load time.
  existing_config = saving_scope.get_config(obj)
  if existing_config is not None:
    return existing_config
  return saving_scope.create_config(base_config, obj)
344
345
@keras_export('keras.utils.register_keras_serializable')
def register_keras_serializable(package='Custom', name=None):
  """Registers an object with the Keras serialization framework.

  The returned decorator adds the decorated class or function to the Keras
  custom object dictionary, so that it can be serialized and deserialized
  without needing an entry in the user-provided custom object dict. It also
  records the string key under which Keras will serialize the object.

  Note that to be serialized and deserialized, classes must implement the
  `get_config()` method. Functions do not have this requirement.

  The object will be registered under the key 'package>name' where `name`,
  defaults to the object name if not passed.

  Args:
    package: The package that this class belongs to.
    name: The name to serialize this class under in this package. If None, the
      class' name will be used.

  Returns:
    A decorator that registers the decorated class with the passed names.
    The decorator raises `ValueError` if a class lacks `get_config`, or if
    either the key or the object has already been registered.
  """

  def decorator(arg):
    """Registers a class with the Keras serialization framework."""
    class_name = arg.__name__ if name is None else name
    registered_name = package + '>' + class_name

    is_class = tf_inspect.isclass(arg)
    if is_class and not hasattr(arg, 'get_config'):
      raise ValueError(
          'Cannot register a class that does not have a get_config() method.')

    # Refuse duplicates in either direction so that the name -> object and
    # object -> name registries stay consistent with each other.
    if registered_name in _GLOBAL_CUSTOM_OBJECTS:
      previous_object = _GLOBAL_CUSTOM_OBJECTS[registered_name]
      raise ValueError(
          '%s has already been registered to %s' %
          (registered_name, previous_object))

    if arg in _GLOBAL_CUSTOM_NAMES:
      previous_name = _GLOBAL_CUSTOM_NAMES[arg]
      raise ValueError('%s has already been registered to %s' %
                       (arg, previous_name))

    _GLOBAL_CUSTOM_OBJECTS[registered_name] = arg
    _GLOBAL_CUSTOM_NAMES[arg] = registered_name
    return arg

  return decorator
393
394
@keras_export('keras.utils.get_registered_name')
def get_registered_name(obj):
  """Returns the name registered to an object within the Keras framework.

  This function is part of the Keras serialization and deserialization
  framework. It translates an object into the string name used for it during
  serialization/deserialization.

  Args:
    obj: The object to look up.

  Returns:
    The name associated with the object, or the default Python name if the
      object is not registered.
  """
  # `obj.__name__` is only evaluated when the object is not registered.
  return (_GLOBAL_CUSTOM_NAMES[obj]
          if obj in _GLOBAL_CUSTOM_NAMES else obj.__name__)
414
415
@tf_contextlib.contextmanager
def skip_failed_serialization():
  """Context manager that turns on `_SKIP_FAILED_SERIALIZATION`.

  While active, the module-level flag is `True`; `serialize_keras_object`
  consults it to return a placeholder config instead of re-raising a
  `NotImplementedError` from `get_config`. The previous value of the flag is
  restored on exit, including when the body raises.
  """
  global _SKIP_FAILED_SERIALIZATION
  previous_flag = _SKIP_FAILED_SERIALIZATION
  _SKIP_FAILED_SERIALIZATION = True
  try:
    yield
  finally:
    _SKIP_FAILED_SERIALIZATION = previous_flag
425
426
@keras_export('keras.utils.get_registered_object')
def get_registered_object(name, custom_objects=None, module_objects=None):
  """Returns the class associated with `name` if it is registered with Keras.

  This function is part of the Keras serialization and deserialization
  framework. It translates a string name into the object registered for it.
  The global custom object registry is consulted first, then
  `custom_objects`, then `module_objects`.

  Example:
  ```
  def from_config(cls, config, custom_objects=None):
    if 'my_custom_object_name' in config:
      config['hidden_cls'] = tf.keras.utils.get_registered_object(
          config['my_custom_object_name'], custom_objects=custom_objects)
  ```

  Args:
    name: The name to look up.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, custom_objects is provided by the user.
    module_objects: A dictionary of custom objects to look the name up in.
      Generally, module_objects is provided by midlevel library implementers.

  Returns:
    An instantiable class associated with 'name', or None if no such class
      exists.
  """
  if name in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[name]

  # Fall back to the caller-supplied dictionaries, in priority order. Missing
  # or empty dictionaries are skipped.
  for lookup in (custom_objects, module_objects):
    if lookup and name in lookup:
      return lookup[name]
  return None
461
462
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
  """Serialize a Keras object into a JSON-compatible representation.

  Calls to `serialize_keras_object` while underneath the
  `SharedObjectSavingScope` context manager will cause any objects re-used
  across multiple layers to be saved with a special shared object ID. This
  allows the network to be re-created properly during deserialization.

  Args:
    instance: The object to serialize.

  Returns:
    `None` if `instance` is `None`; a dict-like, JSON-compatible
    representation of the object's config if it has a `get_config` method;
    otherwise its registered (or Python) name.

  Raises:
    ValueError: If the object has neither `get_config` nor `__name__`.
    NotImplementedError: Re-raised from `get_config` unless
      `_SKIP_FAILED_SERIALIZATION` is set.
  """
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None

  if not hasattr(instance, 'get_config'):
    # Objects without a config (e.g. plain functions) are stored by name.
    if hasattr(instance, '__name__'):
      return get_registered_name(instance)
    raise ValueError('Cannot serialize', instance)

  name = get_registered_name(instance.__class__)
  try:
    config = instance.get_config()
  except NotImplementedError:
    if not _SKIP_FAILED_SERIALIZATION:
      raise
    # Emit a placeholder marking that the config is undefined.
    return serialize_keras_class_and_config(
        name, {_LAYER_UNDEFINED_CONFIG_KEY: True})

  serialization_config = {}
  for key, value in config.items():
    if isinstance(value, six.string_types):
      serialization_config[key] = value
      continue

    # Any object of a different type needs to be converted to string or dict
    # for serialization (e.g. custom functions, custom classes)
    try:
      serialized_value = serialize_keras_object(value)
      # Tag dicts produced from non-dict values so that deserialization knows
      # to turn them back into objects.
      if (isinstance(serialized_value, dict) and
          not isinstance(value, dict)):
        serialized_value['__passive_serialization__'] = True
      serialization_config[key] = serialized_value
    except ValueError:
      # Values that cannot be serialized are stored unchanged.
      serialization_config[key] = value

  return serialize_keras_class_and_config(
      name, serialization_config, instance)
513
514
def get_custom_objects_by_name(item, custom_objects=None):
  """Returns the item if it is in either local or global custom objects.

  Args:
    item: The name to look up.
    custom_objects: Optional dictionary consulted after the global registry.

  Returns:
    The object stored under `item`, preferring the global registry, or `None`
    if neither dictionary contains it.
  """
  if item in _GLOBAL_CUSTOM_OBJECTS:
    return _GLOBAL_CUSTOM_OBJECTS[item]
  found_locally = bool(custom_objects) and item in custom_objects
  return custom_objects[item] if found_locally else None
522
523
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Returns the class name and config for a serialized keras object.

  Note that nested entries which get deserialized are written back into
  `config['config']` in place.

  Args:
    config: Dict holding at least the `'class_name'` and `'config'` keys.
    module_objects: Optional dictionary used to resolve the class name.
    custom_objects: Optional dictionary used to resolve the class name and
      custom function names found in the config.
    printable_module_name: Human-readable type name used in error messages.

  Returns:
    A `(cls, cls_config)` tuple.

  Raises:
    ValueError: If `config` is malformed or the class name is unknown.
  """
  is_well_formed = (isinstance(config, dict)
                    and 'class_name' in config
                    and 'config' in config)
  if not is_well_formed:
    raise ValueError('Improper config format: ' + str(config))

  class_name = config['class_name']
  cls = get_registered_object(class_name, custom_objects, module_objects)
  if cls is None:
    raise ValueError(
        'Unknown {}: {}. Please ensure this object is '
        'passed to the `custom_objects` argument. See '
        'https://www.tensorflow.org/guide/keras/save_and_serialize'
        '#registering_the_custom_object for details.'
        .format(printable_module_name, class_name))

  cls_config = config['config']
  # Check if `cls_config` is a list. If it is a list, return the class and the
  # associated class configs for recursively deserialization. This case will
  # happen on the old version of sequential model (e.g. `keras_version` ==
  # "2.0.6"), which is serialized in a different structure, for example
  # "{'class_name': 'Sequential',
  #   'config': [{'class_name': 'Embedding', 'config': ...}, {}, ...]}".
  if isinstance(cls_config, list):
    return (cls, cls_config)

  # Collect replacements first and apply them after the scan, so `cls_config`
  # is not written to while it is being iterated.
  replacements = {}
  for key, item in cls_config.items():
    if isinstance(item, dict) and '__passive_serialization__' in item:
      replacements[key] = deserialize_keras_object(
          item,
          module_objects=module_objects,
          custom_objects=custom_objects,
          printable_module_name='config_item')
      continue
    if not isinstance(item, six.string_types):
      continue
    # TODO(momernick): Should this also have 'module_objects'?
    candidate = get_registered_object(item, custom_objects)
    if tf_inspect.isfunction(candidate):
      # Handle custom functions here. When saving functions, we only save the
      # function's name as a string. If we find a matching string in the custom
      # objects during deserialization, we convert the string back to the
      # original function.
      # Note that a potential issue is that a string field could have a naming
      # conflict with a custom function name, but this should be a rare case.
      # This issue does not occur if a string field has a naming conflict with
      # a custom object, since the config of an object will always be a dict.
      replacements[key] = candidate

  for key, replacement in replacements.items():
    cls_config[key] = replacement

  return (cls, cls_config)
579
580
@keras_export('keras.utils.deserialize_keras_object')
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  """Turns the serialized form of a Keras object back into an actual object.

  Calls to `deserialize_keras_object` while underneath the
  `SharedObjectLoadingScope` context manager will cause any already-seen shared
  objects to be returned as-is rather than creating a new object.

  Args:
    identifier: the serialized form of the object. May be `None`, a config
      dict, a string name, or an already-deserialized function.
    module_objects: A dictionary of custom objects to look the name up in.
      Generally, module_objects is provided by midlevel library implementers.
    custom_objects: A dictionary of custom objects to look the name up in.
      Generally, custom_objects is provided by the user.
    printable_module_name: A human-readable string representing the type of the
      object. Printed in case of exception.

  Returns:
    The deserialized object.

  Raises:
    ValueError: If a string identifier cannot be resolved to a known object,
      or if `identifier` is of a type that cannot be interpreted.
  """
  if identifier is None:
    return None

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    config = identifier
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        config, module_objects, custom_objects, printable_module_name)

    # If this object has already been loaded (i.e. it's shared between multiple
    # objects), return the already-loaded object.
    shared_object_id = config.get(SHARED_OBJECT_KEY)
    shared_object = _shared_object_loading_scope().get(shared_object_id)  # pylint: disable=assignment-from-none
    if shared_object is not None:
      return shared_object

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      custom_objects = custom_objects or {}

      if 'custom_objects' in arg_spec.args:
        # User-supplied entries override global ones with the same name.
        deserialized_obj = cls.from_config(
            cls_config,
            custom_objects=dict(
                list(_GLOBAL_CUSTOM_OBJECTS.items()) +
                list(custom_objects.items())))
      else:
        with CustomObjectScope(custom_objects):
          deserialized_obj = cls.from_config(cls_config)
    else:
      # Then `cls` may be a function returning a class.
      # in this case by convention `config` holds
      # the kwargs of the function.
      custom_objects = custom_objects or {}
      with CustomObjectScope(custom_objects):
        deserialized_obj = cls(**cls_config)

    # Add object to shared objects, in case we find it referenced again.
    _shared_object_loading_scope().set(shared_object_id, deserialized_obj)

    return deserialized_obj

  elif isinstance(identifier, six.string_types):
    object_name = identifier
    if custom_objects and object_name in custom_objects:
      obj = custom_objects.get(object_name)
    elif object_name in _GLOBAL_CUSTOM_OBJECTS:
      obj = _GLOBAL_CUSTOM_OBJECTS[object_name]
    else:
      # `module_objects` defaults to None. Guard the lookup so that an unknown
      # name produces the descriptive ValueError below instead of an
      # AttributeError on `None.get`.
      obj = (module_objects.get(object_name)
             if module_objects is not None else None)
      if obj is None:
        raise ValueError(
            'Unknown {}: {}. Please ensure this object is '
            'passed to the `custom_objects` argument. See '
            'https://www.tensorflow.org/guide/keras/save_and_serialize'
            '#registering_the_custom_object for details.'
            .format(printable_module_name, object_name))

    # Classes passed by name are instantiated with no args, functions are
    # returned as-is.
    if tf_inspect.isclass(obj):
      return obj()
    return obj
  elif tf_inspect.isfunction(identifier):
    # If a function has already been deserialized, return as is.
    return identifier
  else:
    raise ValueError('Could not interpret serialized %s: %s' %
                     (printable_module_name, identifier))
673
674
def func_dump(func):
  """Serializes a user defined function.

  The function's code object is marshalled and base64-encoded into an ASCII
  string. When `os.name` is `'nt'`, backslash bytes in the marshalled data are
  replaced with forward slashes before encoding.

  Args:
      func: the function to serialize.

  Returns:
      A tuple `(code, defaults, closure)`, where `closure` holds the contents
      of the function's closure cells, or `None` if it has no closure.
  """
  raw_code = marshal.dumps(func.__code__)
  if os.name == 'nt':
    raw_code = raw_code.replace(b'\\', b'/')
  code = codecs.encode(raw_code, 'base64').decode('ascii')

  cells = func.__closure__
  closure = tuple(cell.cell_contents for cell in cells) if cells else None
  return code, func.__defaults__, closure
696
697
def func_load(code, defaults=None, closure=None, globs=None):
  """Deserializes a user defined function.

  NOTE(security): the bytecode is restored with `marshal.loads`, which must
  never be fed data from an untrusted source.

  Args:
      code: bytecode of the function, as the base64 string produced by
        `func_dump`, or a `(code, defaults, closure)` tuple/list from a
        previous dump (in which case the other arguments are overridden).
      defaults: defaults of the function. A list is accepted and converted to
        a tuple.
      closure: closure of the function, as an iterable of values and/or cell
        objects.
      globs: dictionary of global objects. Defaults to this module's globals.

  Returns:
      A function object.
  """
  if isinstance(code, (tuple, list)):  # unpack previous dump
    code, defaults, closure = code

  # `FunctionType` only accepts a tuple (or None) for `argdefs`, but a list may
  # be supplied here (e.g. after a JSON round trip). Normalize it whether it
  # came from a packed dump or was passed directly.
  if isinstance(defaults, list):
    defaults = tuple(defaults)

  def ensure_value_to_cell(value):
    """Ensures that a value is converted to a python cell object.

    Args:
        value: Any value that needs to be casted to the cell type

    Returns:
        A value wrapped as a cell object (see function "func_load")
    """

    def dummy_fn():
      # pylint: disable=pointless-statement
      value  # just access it so it gets captured in .__closure__

    cell_value = dummy_fn.__closure__[0]
    # Values that already are cells are passed through untouched.
    if not isinstance(value, type(cell_value)):
      return cell_value
    return value

  if closure is not None:
    closure = tuple(ensure_value_to_cell(item) for item in closure)
  try:
    raw_code = codecs.decode(code.encode('ascii'), 'base64')
  except (UnicodeEncodeError, binascii.Error):
    # Not valid ASCII base64; fall back to the raw unicode-escaped bytes.
    raw_code = code.encode('raw_unicode_escape')
  code = marshal.loads(raw_code)
  if globs is None:
    globs = globals()
  return python_types.FunctionType(
      code, globs, name=code.co_name, argdefs=defaults, closure=closure)
745
746
def has_arg(fn, name, accept_all=False):
  """Checks if a callable accepts a given keyword argument.

  Args:
      fn: Callable to inspect.
      name: Check if `fn` can be called with `name` as a keyword argument.
      accept_all: What to return if there is no parameter called `name` but the
        function accepts a `**kwargs` argument.

  Returns:
      bool, whether `fn` accepts a `name` keyword argument.
  """
  spec = tf_inspect.getfullargspec(fn)
  takes_var_keywords = spec.varkw is not None
  if accept_all and takes_var_keywords:
    return True
  # The name may be declared either as a regular or a keyword-only parameter.
  return any(name in params for params in (spec.args, spec.kwonlyargs))
763
764
765@keras_export('keras.utils.Progbar')
766class Progbar(object):
767  """Displays a progress bar.
768
769  Args:
770      target: Total number of steps expected, None if unknown.
771      width: Progress bar width on screen.
772      verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
773      stateful_metrics: Iterable of string names of metrics that should *not* be
774        averaged over time. Metrics in this list will be displayed as-is. All
775        others will be averaged by the progbar before display.
776      interval: Minimum visual progress update interval (in seconds).
777      unit_name: Display name for step counts (usually "step" or "sample").
778  """
779
780  def __init__(self,
781               target,
782               width=30,
783               verbose=1,
784               interval=0.05,
785               stateful_metrics=None,
786               unit_name='step'):
787    self.target = target
788    self.width = width
789    self.verbose = verbose
790    self.interval = interval
791    self.unit_name = unit_name
792    if stateful_metrics:
793      self.stateful_metrics = set(stateful_metrics)
794    else:
795      self.stateful_metrics = set()
796
797    self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
798                              sys.stdout.isatty()) or
799                             'ipykernel' in sys.modules or
800                             'posix' in sys.modules or
801                             'PYCHARM_HOSTED' in os.environ)
802    self._total_width = 0
803    self._seen_so_far = 0
804    # We use a dict + list to avoid garbage collection
805    # issues found in OrderedDict
806    self._values = {}
807    self._values_order = []
808    self._start = time.time()
809    self._last_update = 0
810
811    self._time_after_first_step = None
812
813  def update(self, current, values=None, finalize=None):
814    """Updates the progress bar.
815
816    Args:
817        current: Index of current step.
818        values: List of tuples: `(name, value_for_last_step)`. If `name` is in
819          `stateful_metrics`, `value_for_last_step` will be displayed as-is.
820          Else, an average of the metric over time will be displayed.
821        finalize: Whether this is the last update for the progress bar. If
822          `None`, defaults to `current >= self.target`.
823    """
824    if finalize is None:
825      if self.target is None:
826        finalize = False
827      else:
828        finalize = current >= self.target
829
830    values = values or []
831    for k, v in values:
832      if k not in self._values_order:
833        self._values_order.append(k)
834      if k not in self.stateful_metrics:
835        # In the case that progress bar doesn't have a target value in the first
836        # epoch, both on_batch_end and on_epoch_end will be called, which will
837        # cause 'current' and 'self._seen_so_far' to have the same value. Force
838        # the minimal value to 1 here, otherwise stateful_metric will be 0s.
839        value_base = max(current - self._seen_so_far, 1)
840        if k not in self._values:
841          self._values[k] = [v * value_base, value_base]
842        else:
843          self._values[k][0] += v * value_base
844          self._values[k][1] += value_base
845      else:
846        # Stateful metrics output a numeric value. This representation
847        # means "take an average from a single value" but keeps the
848        # numeric formatting.
849        self._values[k] = [v, 1]
850    self._seen_so_far = current
851
852    now = time.time()
853    info = ' - %.0fs' % (now - self._start)
854    if self.verbose == 1:
855      if now - self._last_update < self.interval and not finalize:
856        return
857
858      prev_total_width = self._total_width
859      if self._dynamic_display:
860        sys.stdout.write('\b' * prev_total_width)
861        sys.stdout.write('\r')
862      else:
863        sys.stdout.write('\n')
864
865      if self.target is not None:
866        numdigits = int(np.log10(self.target)) + 1
867        bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
868        prog = float(current) / self.target
869        prog_width = int(self.width * prog)
870        if prog_width > 0:
871          bar += ('=' * (prog_width - 1))
872          if current < self.target:
873            bar += '>'
874          else:
875            bar += '='
876        bar += ('.' * (self.width - prog_width))
877        bar += ']'
878      else:
879        bar = '%7d/Unknown' % current
880
881      self._total_width = len(bar)
882      sys.stdout.write(bar)
883
884      time_per_unit = self._estimate_step_duration(current, now)
885
886      if self.target is None or finalize:
887        if time_per_unit >= 1 or time_per_unit == 0:
888          info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
889        elif time_per_unit >= 1e-3:
890          info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
891        else:
892          info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)
893      else:
894        eta = time_per_unit * (self.target - current)
895        if eta > 3600:
896          eta_format = '%d:%02d:%02d' % (eta // 3600,
897                                         (eta % 3600) // 60, eta % 60)
898        elif eta > 60:
899          eta_format = '%d:%02d' % (eta // 60, eta % 60)
900        else:
901          eta_format = '%ds' % eta
902
903        info = ' - ETA: %s' % eta_format
904
905      for k in self._values_order:
906        info += ' - %s:' % k
907        if isinstance(self._values[k], list):
908          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
909          if abs(avg) > 1e-3:
910            info += ' %.4f' % avg
911          else:
912            info += ' %.4e' % avg
913        else:
914          info += ' %s' % self._values[k]
915
916      self._total_width += len(info)
917      if prev_total_width > self._total_width:
918        info += (' ' * (prev_total_width - self._total_width))
919
920      if finalize:
921        info += '\n'
922
923      sys.stdout.write(info)
924      sys.stdout.flush()
925
926    elif self.verbose == 2:
927      if finalize:
928        numdigits = int(np.log10(self.target)) + 1
929        count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
930        info = count + info
931        for k in self._values_order:
932          info += ' - %s:' % k
933          avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
934          if avg > 1e-3:
935            info += ' %.4f' % avg
936          else:
937            info += ' %.4e' % avg
938        info += '\n'
939
940        sys.stdout.write(info)
941        sys.stdout.flush()
942
943    self._last_update = now
944
945  def add(self, n, values=None):
946    self.update(self._seen_so_far + n, values)
947
948  def _estimate_step_duration(self, current, now):
949    """Estimate the duration of a single step.
950
951    Given the step number `current` and the corresponding time `now`
952    this function returns an estimate for how long a single step
953    takes. If this is called before one step has been completed
954    (i.e. `current == 0`) then zero is given as an estimate. The duration
955    estimate ignores the duration of the (assumed to be non-representative)
956    first step for estimates when more steps are available (i.e. `current>1`).
957    Args:
958      current: Index of current step.
959      now: The current time.
960    Returns: Estimate of the duration of a single step.
961    """
962    if current:
963      # there are a few special scenarios here:
964      # 1) somebody is calling the progress bar without ever supplying step 1
965      # 2) somebody is calling the progress bar and supplies step one mulitple
966      #    times, e.g. as part of a finalizing call
967      # in these cases, we just fall back to the simple calculation
968      if self._time_after_first_step is not None and current > 1:
969        time_per_unit = (now - self._time_after_first_step) / (current - 1)
970      else:
971        time_per_unit = (now - self._start) / current
972
973      if current == 1:
974        self._time_after_first_step = now
975      return time_per_unit
976    else:
977      return 0
978
979
def make_batches(size, batch_size):
  """Splits the index range `[0, size)` into consecutive batch bounds.

  Args:
      size: Integer, total size of the data to slice into batches.
      batch_size: Integer, batch size.

  Returns:
      A list of `(start, end)` tuples of array indices. The last tuple is
      clipped to `size`.
  """
  batch_count = int(np.ceil(size / float(batch_size)))
  bounds = []
  for index in range(batch_count):
    begin = index * batch_size
    end = min(size, (index + 1) * batch_size)
    bounds.append((begin, end))
  return bounds
993
994
def slice_arrays(arrays, start=None, stop=None):
  """Slice an array or list of arrays.

  This takes an array-like, or a list of
  array-likes, and outputs:
      - arrays[start:stop] if `arrays` is an array-like
      - [x[start:stop] for x in arrays] if `arrays` is a list

  Can also work on list/array of indices: `slice_arrays(x, indices)`

  Entries of a list that are `None` or that do not support indexing are
  returned as `None`. A single non-list `arrays` object that does not support
  indexing yields `[None]`.

  Args:
      arrays: Single array or list of arrays.
      start: can be an integer index (start index) or a list/array of indices
      stop: integer (stop index); should be None if `start` was a list.

  Returns:
      A slice of the array(s).

  Raises:
      ValueError: If the value of start is a list and stop is not None.
  """
  if arrays is None:
    return [None]
  if isinstance(start, list) and stop is not None:
    raise ValueError('The stop argument has to be None if the value of start '
                     'is a list.')

  if isinstance(arrays, list):
    if hasattr(start, '__len__'):
      # hdf5 datasets only support list objects as indices
      if hasattr(start, 'shape'):
        start = start.tolist()
      return [None if x is None else x[start] for x in arrays]
    # `None` has no `__getitem__`, so this single check also covers it.
    return [
        x[start:stop] if hasattr(x, '__getitem__') else None for x in arrays
    ]

  if hasattr(start, '__len__'):
    if hasattr(start, 'shape'):
      start = start.tolist()
    return arrays[start]
  # The sliceability check must look at `arrays`, as the list branch does for
  # each element. Checking `start` (an int or None here) was never true and
  # made every plain `start:stop` slice of a single array return `[None]`.
  if hasattr(arrays, '__getitem__'):
    return arrays[start:stop]
  return [None]
1039
1040
def to_list(x):
  """Wraps a non-list object in a one-element list.

  A `list` argument is returned as the same object (no copy is made). Any
  other object, e.g. a single tensor, is returned as `[x]`.

  Args:
      x: target object to be normalized.

  Returns:
      A list.
  """
  return x if isinstance(x, list) else [x]
1056
1057
def to_snake_case(name):
  """Converts a CamelCase name to snake_case.

  Args:
      name: String to convert, e.g. a class name.

  Returns:
      The lower-cased, underscore-separated name. If the result starts with
      "_" it is prefixed with "private". An empty `name` yields an empty
      string.
  """
  intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
  # If the class is private the name starts with "_" which is not secure
  # for creating scopes. We prefix the name with "private" in this case.
  # `startswith` is used instead of indexing so an empty name does not raise.
  if insecure.startswith('_'):
    return 'private' + insecure
  return insecure
1066
1067
def is_all_none(structure):
  """Returns True when every element of the flattened structure is None."""
  # Only identity checks against None are performed, so no element is ever
  # evaluated for truthiness (elements may be Tensors).
  return all(element is None for element in nest.flatten(structure))
1075
1076
def check_for_unexpected_keys(name, input_dict, expected_values):
  """Raises if `input_dict` has keys that are not in `expected_values`.

  Args:
      name: Label for the dictionary, used in the error message.
      input_dict: Dictionary whose keys are checked.
      expected_values: Iterable of allowed keys.

  Raises:
      ValueError: If at least one key of `input_dict` is not allowed.
  """
  unexpected = set(input_dict.keys()).difference(expected_values)
  if not unexpected:
    return
  template = ('Unknown entries in {} dictionary: {}. Only expected '
              'following keys: {}')
  raise ValueError(template.format(name, list(unexpected), expected_values))
1083
1084
def validate_kwargs(kwargs,
                    allowed_kwargs,
                    error_message='Keyword argument not understood:'):
  """Checks that all keyword arguments are in the set of allowed keys.

  Args:
      kwargs: Iterable of keyword names (typically a kwargs dict).
      allowed_kwargs: Container of the accepted keyword names.
      error_message: First argument of the raised error.

  Raises:
      TypeError: On the first name that is not allowed; the error arguments
        are `(error_message, name)`.
  """
  for key in kwargs:
    if key in allowed_kwargs:
      continue
    raise TypeError(error_message, key)
1092
1093
def validate_config(config):
  """Determines whether config appears to be a valid layer config.

  A config is treated as valid when it is a dict that does not contain the
  `_LAYER_UNDEFINED_CONFIG_KEY` marker key.
  """
  if not isinstance(config, dict):
    return False
  return _LAYER_UNDEFINED_CONFIG_KEY not in config
1097
1098
def default(method):
  """Decorates a method to detect overrides in subclasses.

  The method is tagged with an `_is_default` attribute set to True and is
  returned unchanged otherwise.
  """
  setattr(method, '_is_default', True)
  return method
1103
1104
def is_default(method):
  """Check if a method is decorated with the `default` wrapper.

  Returns the object's `_is_default` attribute, or False when it has none.
  """
  try:
    return method._is_default  # pylint: disable=protected-access
  except AttributeError:
    return False
1108
1109
def populate_dict_with_module_objects(target_dict, modules, obj_filter):
  """Copies the attributes accepted by `obj_filter` into `target_dict`.

  Args:
      target_dict: Dictionary updated in place, keyed by attribute name.
      modules: Iterable of objects (modules) whose `dir()` entries are scanned.
      obj_filter: Callable taking an attribute value; the attribute is stored
        when it returns a truthy value.
  """
  for module in modules:
    # Lazy pairs keep the getattr / filter calls interleaved per attribute.
    members = ((attr, getattr(module, attr)) for attr in dir(module))
    for attr, obj in members:
      if not obj_filter(obj):
        continue
      target_dict[attr] = obj
1116
1117
class LazyLoader(python_types.ModuleType):
  """Module placeholder that defers the real import until first use.

  The import is triggered by the first attribute lookup that is not already
  satisfied by this object's own `__dict__`. This is mainly used to avoid
  pulling in large dependencies.
  """

  def __init__(self, local_name, parent_module_globals, name):
    """Records where the real module must be installed once it is imported.

    Args:
        local_name: Key under which the module is stored in the parent globals.
        parent_module_globals: Globals dictionary of the importing module.
        name: Fully qualified name of the module to import lazily.
    """
    super(LazyLoader, self).__init__(name)
    self._parent_module_globals = parent_module_globals
    self._local_name = local_name

  def _load(self):
    """Imports the real module and binds it in the parent's globals."""
    loaded = importlib.import_module(self.__name__)
    # From now on the parent namespace refers to the real module directly.
    self._parent_module_globals[self._local_name] = loaded
    # Mirror the module's attributes on this object: `__getattr__` only runs
    # for lookups that fail, so references kept to the LazyLoader resolve
    # without importing again.
    self.__dict__.update(loaded.__dict__)
    return loaded

  def __getattr__(self, item):
    return getattr(self._load(), item)
1140
1141
# Aliases

# Lower-case name bound to the same object as `CustomObjectScope`.
custom_object_scope = CustomObjectScope  # pylint: disable=invalid-name
1145