1# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15"""Python utilities required by Keras."""
16from __future__ import absolute_import
17from __future__ import division
18from __future__ import print_function
19
20import binascii
21import codecs
22import marshal
23import os
24import re
25import sys
26import time
27import types as python_types
28
29import numpy as np
30import six
31
32from tensorflow.python.util import nest
33from tensorflow.python.util import tf_decorator
34from tensorflow.python.util import tf_inspect
35from tensorflow.python.util.tf_export import keras_export
36
# Process-wide registry mapping names to user-supplied objects. It is consulted
# by the (de)serialization helpers in this module and temporarily extended by
# `CustomObjectScope`; `get_custom_objects()` hands out this very dict.
_GLOBAL_CUSTOM_OBJECTS = dict()
38
39
@keras_export('keras.utils.CustomObjectScope')
class CustomObjectScope(object):
  """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.

  On entry, every dictionary given to the constructor is merged, in order,
  into the global custom-object registry, so code inside the `with`
  statement can refer to those objects by name. On exit the registry is
  restored to exactly the contents it had when the scope was entered, which
  discards both the merged entries and any other change made inside the
  scope.

  Example:

  Consider a custom object `MyObject` (e.g. a class):

  ```python
      with CustomObjectScope({'MyObject':MyObject}):
          layer = Dense(..., kernel_regularizer='MyObject')
          # save, load, etc. will recognize custom object by name
  ```

  Arguments:
      *args: Dictionaries mapping names to custom objects. On key
          collisions, later dictionaries override earlier ones.
  """

  def __init__(self, *args):
    # Dictionaries to merge into the global registry on `__enter__`.
    self.custom_objects = args
    # Snapshot of the registry taken on `__enter__`; None until then.
    self.backup = None

  def __enter__(self):
    self.backup = dict(_GLOBAL_CUSTOM_OBJECTS)
    for mapping in self.custom_objects:
      _GLOBAL_CUSTOM_OBJECTS.update(mapping)
    return self

  def __exit__(self, *args, **kwargs):
    # Restore in place rather than rebinding the global, so references
    # handed out by `get_custom_objects()` keep pointing at the registry.
    _GLOBAL_CUSTOM_OBJECTS.clear()
    _GLOBAL_CUSTOM_OBJECTS.update(self.backup)
74
75
@keras_export('keras.utils.custom_object_scope')
def custom_object_scope(*args):
  """Provides a scope that changes to `_GLOBAL_CUSTOM_OBJECTS` cannot escape.

  Functional alias for `CustomObjectScope`. Inside the `with` statement the
  given custom objects can be referenced by name, and changes to the global
  custom objects persist there. When the `with` statement ends, the global
  custom objects are reverted to their state at its beginning.

  Example:

  Consider a custom object `MyObject`

  ```python
      with custom_object_scope({'MyObject':MyObject}):
          layer = Dense(..., kernel_regularizer='MyObject')
          # save, load, etc. will recognize custom object by name
  ```

  Arguments:
      *args: Variable length list of dictionaries of name,
          class pairs to add to custom objects.

  Returns:
      Object of type `CustomObjectScope`.
  """
  scope = CustomObjectScope(*args)
  return scope
105
106
@keras_export('keras.utils.get_custom_objects')
def get_custom_objects():
  """Retrieves a live reference to the global dictionary of custom objects.

  The returned dict is the registry itself, not a copy, so mutating it
  changes what the (de)serialization helpers can resolve by name.
  `custom_object_scope` is the preferred way to add entries temporarily,
  but this function gives direct access to `_GLOBAL_CUSTOM_OBJECTS`.

  Example:

  ```python
      get_custom_objects().clear()
      get_custom_objects()['MyObject'] = MyObject
  ```

  Returns:
      Global dictionary of names to classes (`_GLOBAL_CUSTOM_OBJECTS`).
  """
  registry = _GLOBAL_CUSTOM_OBJECTS
  return registry
126
127
def serialize_keras_class_and_config(cls_name, cls_config):
  """Packs a class name and its config into the Keras serialization format.

  Arguments:
      cls_name: Name under which the class is looked up on deserialization.
      cls_config: The class's config, e.g. the result of `get_config()`.

  Returns:
      A dict with exactly the keys `'class_name'` and `'config'`.
  """
  return dict(class_name=cls_name, config=cls_config)
131
132
@keras_export('keras.utils.serialize_keras_object')
def serialize_keras_object(instance):
  """Serializes a Keras object (or a named callable) into a plain value.

  `instance` is first passed through `tf_decorator.unwrap`, and the
  unwrapped target is what gets serialized.

  Arguments:
      instance: The object to serialize.

  Returns:
      - None if the unwrapped `instance` is None.
      - `{'class_name': ..., 'config': ...}` if it has a `get_config`
        attribute; `class_name` is the name of its class.
      - Its `__name__` string otherwise (e.g. for a plain function).

  Raises:
      ValueError: If `instance` has neither `get_config` nor `__name__`.
  """
  _, instance = tf_decorator.unwrap(instance)
  if instance is None:
    return None
  if hasattr(instance, 'get_config'):
    return serialize_keras_class_and_config(instance.__class__.__name__,
                                            instance.get_config())
  if hasattr(instance, '__name__'):
    return instance.__name__
  # Build a single readable message; passing the object as a second
  # positional argument would render the error as a tuple.
  raise ValueError('Cannot serialize %r: it has neither a `get_config` '
                   'method nor a `__name__` attribute.' % (instance,))
145
146
def class_and_config_for_serialized_keras_object(
    config,
    module_objects=None,
    custom_objects=None,
    printable_module_name='object'):
  """Resolves the class and inner config of a serialized Keras object.

  `config['class_name']` is looked up, in order, in `custom_objects` (when
  non-empty), then the global custom-object registry, then
  `module_objects`.

  Arguments:
      config: Dict with `'class_name'` and `'config'` keys, in the format
          produced by `serialize_keras_class_and_config`.
      module_objects: Optional dict of fallback name -> object.
      custom_objects: Optional dict of user-provided name -> object.
      printable_module_name: Human-readable kind of object, used in errors.

  Returns:
      Tuple `(cls, cls_config)`.

  Raises:
      ValueError: If `config` is malformed, or if the name is not found
          (a `None` entry in `module_objects` also counts as not found).
  """
  well_formed = (isinstance(config, dict) and 'class_name' in config and
                 'config' in config)
  if not well_formed:
    raise ValueError('Improper config format: ' + str(config))

  name = config['class_name']
  if custom_objects and name in custom_objects:
    return (custom_objects[name], config['config'])
  if name in _GLOBAL_CUSTOM_OBJECTS:
    return (_GLOBAL_CUSTOM_OBJECTS[name], config['config'])

  resolved = (module_objects or {}).get(name)
  if resolved is None:
    raise ValueError('Unknown ' + printable_module_name + ': ' + name)
  return (resolved, config['config'])
168
169
@keras_export('keras.utils.deserialize_keras_object')
def deserialize_keras_object(identifier,
                             module_objects=None,
                             custom_objects=None,
                             printable_module_name='object'):
  """Turns a serialized Keras identifier back into a Python object.

  Arguments:
      identifier: One of
          - None: returned unchanged.
          - A config dict `{'class_name': ..., 'config': ...}`: the class is
            resolved with `class_and_config_for_serialized_keras_object`
            and built from the inner config, via `cls.from_config` when the
            class has it, otherwise by calling `cls(**config)`.
          - A string: resolved by name and the resolved object returned.
      module_objects: Optional dict of fallback name -> object.
      custom_objects: Optional dict of user-provided name -> object; it
          takes precedence over the global registry and `module_objects`.
      printable_module_name: Human-readable kind of object, used in errors.

  Returns:
      The deserialized object, or None.

  Raises:
      ValueError: If a name cannot be resolved, or if `identifier` is not
          None, a dict or a string.
  """
  if identifier is None:
    return None
  custom_objects = custom_objects or {}

  if isinstance(identifier, dict):
    # In this case we are dealing with a Keras config dictionary.
    (cls, cls_config) = class_and_config_for_serialized_keras_object(
        identifier, module_objects, custom_objects, printable_module_name)

    if hasattr(cls, 'from_config'):
      arg_spec = tf_inspect.getfullargspec(cls.from_config)
      if 'custom_objects' in arg_spec.args:
        # Explicitly passed custom objects override global registry entries.
        merged = dict(_GLOBAL_CUSTOM_OBJECTS)
        merged.update(custom_objects)
        return cls.from_config(cls_config, custom_objects=merged)
      with CustomObjectScope(custom_objects):
        return cls.from_config(cls_config)

    # Then `cls` may be a function returning a class. In this case, by
    # convention, `cls_config` holds the kwargs of the function.
    with CustomObjectScope(custom_objects):
      return cls(**cls_config)

  if isinstance(identifier, six.string_types):
    function_name = identifier
    if function_name in custom_objects:
      fn = custom_objects[function_name]
    elif function_name in _GLOBAL_CUSTOM_OBJECTS:
      fn = _GLOBAL_CUSTOM_OBJECTS[function_name]
    else:
      # `module_objects` is optional: treat a missing mapping as "not found"
      # instead of failing with AttributeError on None.
      fn = (module_objects or {}).get(function_name)
      if fn is None:
        raise ValueError('Unknown ' + printable_module_name + ':' +
                         function_name)
    return fn

  # `identifier` is neither a dict nor a string here, so it must be
  # converted before concatenation (otherwise a TypeError would mask this
  # ValueError).
  raise ValueError('Could not interpret serialized ' + printable_module_name +
                   ': ' + str(identifier))
217
218
def func_dump(func):
  """Serializes a user defined function.

  Arguments:
      func: the function to serialize.

  Returns:
      A tuple `(code, defaults, closure)`: `code` is the base64-encoded
      (ASCII `str`) `marshal` dump of `func.__code__`, `defaults` is
      `func.__defaults__`, and `closure` is a tuple holding the contents of
      the function's closure cells, or None if it has no closure.
  """
  marshaled = marshal.dumps(func.__code__)
  if os.name == 'nt':
    # Byte-level replacement of every backslash in the marshaled payload.
    # NOTE(review): presumably meant to normalize Windows paths (e.g. in
    # `co_filename`), but it rewrites any 0x5C byte, which may corrupt the
    # bytecode or constants -- confirm before relying on it.
    marshaled = marshaled.replace(b'\\', b'/')
  code = codecs.encode(marshaled, 'base64').decode('ascii')

  cells = func.__closure__
  closure = tuple(cell.cell_contents for cell in cells) if cells else None
  return code, func.__defaults__, closure
240
241
def func_load(code, defaults=None, closure=None, globs=None):
  """Deserializes a user defined function.

  SECURITY: `marshal.loads` is not safe against maliciously constructed
  data; only load code that comes from a trusted source.

  Arguments:
      code: Base64-encoded marshal dump of the function's code object (the
          format `func_dump` produces). A string that is not valid base64 is
          instead encoded with `raw_unicode_escape` and unmarshaled as-is.
          May also be a `(code, defaults, closure)` tuple/list as returned
          by `func_dump`, in which case the `defaults` and `closure`
          arguments are ignored.
      defaults: defaults of the function, as a tuple or list, or None.
      closure: iterable of values (or cell objects) for the function's
          free variables, or None.
      globs: dictionary of global objects; defaults to this module's
          `globals()`.

  Returns:
      A function object.
  """
  if isinstance(code, (tuple, list)):  # unpack previous dump
    code, defaults, closure = code
  if isinstance(defaults, list):
    # `FunctionType` only accepts a tuple (or None) for `argdefs`, while
    # serializers such as JSON hand back lists. Normalize for both the
    # packed form and an explicitly passed `defaults` list.
    defaults = tuple(defaults)

  def ensure_value_to_cell(value):
    """Ensures that a value is converted to a python cell object.

    Arguments:
        value: Any value that needs to be casted to the cell type

    Returns:
        `value` itself if it already is a cell; otherwise a new cell
        holding `value` (taken from a throwaway closure).
    """
    def dummy_fn():
      # pylint: disable=pointless-statement
      value  # just access it so it gets captured in .__closure__

    cell_value = dummy_fn.__closure__[0]
    if not isinstance(value, type(cell_value)):
      return cell_value
    else:
      return value

  if closure is not None:
    closure = tuple(ensure_value_to_cell(_) for _ in closure)
  try:
    raw_code = codecs.decode(code.encode('ascii'), 'base64')
  except (UnicodeEncodeError, binascii.Error):
    # Not base64: fall back to treating the string as raw escaped bytes.
    raw_code = code.encode('raw_unicode_escape')
  code = marshal.loads(raw_code)
  if globs is None:
    globs = globals()
  return python_types.FunctionType(
      code, globs, name=code.co_name, argdefs=defaults, closure=closure)
289
290
def has_arg(fn, name, accept_all=False):
  """Checks if a callable accepts a given keyword argument.

  Both regular positional-or-keyword parameters and keyword-only
  parameters (declared after `*` or `*args`) are considered.

  Arguments:
      fn: Callable to inspect.
      name: Check if `fn` can be called with `name` as a keyword argument.
      accept_all: If True, return True whenever `fn` accepts a `**kwargs`
          argument, even if it has no parameter called `name`.

  Returns:
      bool, whether `fn` accepts a `name` keyword argument.
  """
  arg_spec = tf_inspect.getfullargspec(fn)
  if accept_all and arg_spec.varkw is not None:
    return True
  # Keyword-only parameters can be passed by keyword too. `or ()` guards
  # against an argspec that reports None here.
  return name in arg_spec.args or name in (arg_spec.kwonlyargs or ())
307
308
@keras_export('keras.utils.Progbar')
class Progbar(object):
  """Displays a progress bar.

  Arguments:
      target: Total number of steps expected, None if unknown.
      width: Progress bar width on screen.
      verbose: Verbosity mode, 0 (silent), 1 (verbose), 2 (semi-verbose)
      stateful_metrics: Iterable of string names of metrics that
          should *not* be averaged over time. Metrics in this list
          will be displayed as-is. All others will be averaged
          by the progbar before display.
      interval: Minimum visual progress update interval (in seconds).
      unit_name: Display name for step counts (usually "step" or "sample").
  """

  def __init__(self, target, width=30, verbose=1, interval=0.05,
               stateful_metrics=None, unit_name='step'):
    self.target = target
    self.width = width
    self.verbose = verbose
    self.interval = interval
    self.unit_name = unit_name
    if stateful_metrics:
      self.stateful_metrics = set(stateful_metrics)
    else:
      self.stateful_metrics = set()

    # When True, each verbose=1 update erases the previous line with
    # backspaces and '\r'; otherwise every update starts a new line.
    self._dynamic_display = ((hasattr(sys.stdout, 'isatty') and
                              sys.stdout.isatty()) or
                             'ipykernel' in sys.modules or
                             'posix' in sys.modules)
    # Length of the last line written, used to erase/pad the next one.
    self._total_width = 0
    self._seen_so_far = 0
    # We use a dict + list to avoid garbage collection
    # issues found in OrderedDict
    self._values = {}
    self._values_order = []
    self._start = time.time()
    self._last_update = 0

  def update(self, current, values=None):
    """Updates the progress bar.

    With verbose=1 the bar is redrawn at most once per `interval` seconds
    (and always once `target` is reached); with verbose=2 a single summary
    line is printed once `current` reaches `target`.

    Arguments:
        current: Index of current step.
        values: List of tuples:
            `(name, value_for_last_step)`.
            If `name` is in `stateful_metrics`,
            `value_for_last_step` will be displayed as-is.
            Else, an average of the metric over time will be displayed.
    """
    values = values or []
    for k, v in values:
      if k not in self._values_order:
        self._values_order.append(k)
      if k not in self.stateful_metrics:
        # Keep a step-weighted sum and the number of steps it covers, so the
        # displayed value is the running average over all steps seen.
        steps = current - self._seen_so_far
        if k not in self._values:
          self._values[k] = [v * steps, steps]
        else:
          self._values[k][0] += v * steps
          self._values[k][1] += steps
      else:
        # Stateful metrics output a numeric value. This representation
        # means "take an average from a single value" but keeps the
        # numeric formatting.
        self._values[k] = [v, 1]
    self._seen_so_far = current

    now = time.time()
    info = ' - %.0fs' % (now - self._start)
    if self.verbose == 1:
      if (now - self._last_update < self.interval and
          self.target is not None and current < self.target):
        return

      prev_total_width = self._total_width
      if self._dynamic_display:
        sys.stdout.write('\b' * prev_total_width)
        sys.stdout.write('\r')
      else:
        sys.stdout.write('\n')

      if self.target is not None:
        numdigits = int(np.log10(self.target)) + 1
        bar = ('%' + str(numdigits) + 'd/%d [') % (current, self.target)
        prog = float(current) / self.target
        prog_width = int(self.width * prog)
        if prog_width > 0:
          bar += ('=' * (prog_width - 1))
          if current < self.target:
            bar += '>'
          else:
            bar += '='
        bar += ('.' * (self.width - prog_width))
        bar += ']'
      else:
        bar = '%7d/Unknown' % current

      self._total_width = len(bar)
      sys.stdout.write(bar)

      if current:
        time_per_unit = (now - self._start) / current
      else:
        time_per_unit = 0
      if self.target is not None and current < self.target:
        eta = time_per_unit * (self.target - current)
        if eta > 3600:
          eta_format = '%d:%02d:%02d' % (eta // 3600,
                                         (eta % 3600) // 60,
                                         eta % 60)
        elif eta > 60:
          eta_format = '%d:%02d' % (eta // 60, eta % 60)
        else:
          eta_format = '%ds' % eta

        # While in progress, the elapsed-time prefix is replaced by the ETA.
        info = ' - ETA: %s' % eta_format
      else:
        if time_per_unit >= 1 or time_per_unit == 0:
          info += ' %.0fs/%s' % (time_per_unit, self.unit_name)
        elif time_per_unit >= 1e-3:
          info += ' %.0fms/%s' % (time_per_unit * 1e3, self.unit_name)
        else:
          info += ' %.0fus/%s' % (time_per_unit * 1e6, self.unit_name)

      info += self._format_values()

      self._total_width += len(info)
      # Pad with spaces so a shorter line fully overwrites a longer one.
      if prev_total_width > self._total_width:
        info += (' ' * (prev_total_width - self._total_width))

      if self.target is not None and current >= self.target:
        info += '\n'

      sys.stdout.write(info)
      sys.stdout.flush()

    elif self.verbose == 2:
      if self.target is not None and current >= self.target:
        numdigits = int(np.log10(self.target)) + 1
        count = ('%' + str(numdigits) + 'd/%d') % (current, self.target)
        info = count + info
        info += self._format_values()
        info += '\n'

        sys.stdout.write(info)
        sys.stdout.flush()

    self._last_update = now

  def add(self, n, values=None):
    """Advances the bar by `n` steps; `values` is as in `update`."""
    self.update(self._seen_so_far + n, values)

  def _format_values(self):
    """Formats the tracked metrics as ' - name: value' segments.

    Shared by both verbosity modes so they format metrics identically:
    averages whose magnitude exceeds 1e-3 use fixed-point notation, all
    others (including small negative values) use scientific notation.

    Returns:
        The concatenated segments, in first-seen metric order.
    """
    info = ''
    for k in self._values_order:
      info += ' - %s:' % k
      if isinstance(self._values[k], list):
        avg = np.mean(self._values[k][0] / max(1, self._values[k][1]))
        if abs(avg) > 1e-3:
          info += ' %.4f' % avg
        else:
          info += ' %.4e' % avg
      else:
        info += ' %s' % self._values[k]
    return info
478
479
def make_batches(size, batch_size):
  """Splits `range(size)` into consecutive `(start, end)` index pairs.

  Every pair spans `batch_size` indices except possibly the last, which is
  truncated at `size`.

  Arguments:
      size: Integer, total size of the data to slice into batches.
      batch_size: Integer, batch size.

  Returns:
      A list of `(start, end)` tuples with `end` exclusive.
  """
  batch_count = int(np.ceil(size / float(batch_size)))
  batches = []
  for index in range(batch_count):
    begin = index * batch_size
    end = min(size, (index + 1) * batch_size)
    batches.append((begin, end))
  return batches
493
494
def slice_arrays(arrays, start=None, stop=None):
  """Slice an array or list of arrays.

  This takes an array-like, or a list of
  array-likes, and outputs:
      - arrays[start:stop] if `arrays` is an array-like
      - [x[start:stop] for x in arrays] if `arrays` is a list

  Can also work on list/array of indices: `slice_arrays(x, indices)`

  Arguments:
      arrays: Single array or list of arrays. `None` entries in a list are
          passed through as `None`.
      start: can be an integer index (start index)
          or a list/array of indices
      stop: integer (stop index); should be None if
          `start` was a list.

  Returns:
      A slice of the array(s). `[None]` if `arrays` is None or is a
      single object that does not support indexing.

  Raises:
      ValueError: If the value of start is a list and stop is not None.
  """
  if arrays is None:
    return [None]
  if isinstance(start, list) and stop is not None:
    raise ValueError('The stop argument has to be None if the value of start '
                     'is a list.')
  elif isinstance(arrays, list):
    if hasattr(start, '__len__'):
      # hdf5 datasets only support list objects as indices
      if hasattr(start, 'shape'):
        start = start.tolist()
      return [None if x is None else x[start] for x in arrays]
    else:
      return [None if x is None else x[start:stop] for x in arrays]
  else:
    if hasattr(start, '__len__'):
      # hdf5 datasets only support list objects as indices
      if hasattr(start, 'shape'):
        start = start.tolist()
      return arrays[start]
    elif hasattr(arrays, '__getitem__'):
      # Scalar (or None) `start`: the *array* must be sliceable. Testing
      # `start` here made integer starts fall through to `[None]`.
      return arrays[start:stop]
    else:
      return [None]
540
541
def to_list(x):
  """Normalizes a list/tensor into a list.

  A list is returned unchanged (the same object, not a copy). Anything
  else, including a tensor or a tuple, is wrapped in a one-element list.

  Arguments:
      x: target object to be normalized.

  Returns:
      A list.
  """
  return x if isinstance(x, list) else [x]
557
558
def object_list_uid(object_list):
  """Creates a single string from object ids.

  Arguments:
      object_list: A possibly nested structure of objects; it is flattened
          with `nest.flatten` before the ids are taken.

  Returns:
      The absolute value of each leaf's `id()`, as decimal strings joined by
      `', '` in flattened order.
  """
  flat_objects = nest.flatten(object_list)
  return ', '.join(str(abs(id(obj))) for obj in flat_objects)
563
564
def to_snake_case(name):
  """Converts a CamelCase name to snake_case.

  For example 'MyLayer' -> 'my_layer'. If the result starts with '_', it is
  prefixed with 'private' (e.g. '_MyLayer' -> 'private__my_layer').

  Arguments:
      name: The string to convert.

  Returns:
      The converted string; '' for an empty `name`.
  """
  intermediate = re.sub('(.)([A-Z][a-z0-9]+)', r'\1_\2', name)
  insecure = re.sub('([a-z])([A-Z])', r'\1_\2', intermediate).lower()
  # If the class is private the name starts with "_" which is not secure
  # for creating scopes. We prefix the name with "private" in this case.
  # `startswith` is used instead of `insecure[0]`, which raises IndexError
  # on an empty string.
  if not insecure.startswith('_'):
    return insecure
  return 'private' + insecure
573
574
def is_all_none(structure):
  """Returns True iff every leaf of `structure` is None.

  Arguments:
      structure: A possibly nested structure, flattened with `nest.flatten`.

  Returns:
      bool: False as soon as a non-None leaf is found; True otherwise
      (including when flattening yields no leaves).
  """
  # Each `is None` test is a plain bool, so `all` never asks a leaf for its
  # truth value (which may be ambiguous for leaves such as Tensors).
  return all(leaf is None for leaf in nest.flatten(structure))
582