1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Functions for saving and loading a Keras Model from HDF5 format.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import json
23import os
24
25import numpy as np
26from six.moves import zip  # pylint: disable=redefined-builtin
27
28from tensorflow.python.keras import backend as K
29from tensorflow.python.keras import optimizer_v1
30from tensorflow.python.keras.saving import model_config as model_config_lib
31from tensorflow.python.keras.saving import saving_utils
32from tensorflow.python.keras.saving.saved_model import json_utils
33from tensorflow.python.keras.utils.generic_utils import LazyLoader
34from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
35from tensorflow.python.ops import variables as variables_module
36from tensorflow.python.platform import gfile
37from tensorflow.python.platform import tf_logging as logging
38
39
40# pylint: disable=g-import-not-at-top
41try:
42  import h5py
43  HDF5_OBJECT_HEADER_LIMIT = 64512
44except ImportError:
45  h5py = None
46# pylint: enable=g-import-not-at-top
47
48# TODO(b/134426265): Switch back to single-quotes to match the rest of the file
49# once the issue with copybara is fixed.
50# pylint:disable=g-inconsistent-quotes
51sequential_lib = LazyLoader(
52    "sequential_lib", globals(),
53    "tensorflow.python.keras.engine.sequential")
54# pylint:enable=g-inconsistent-quotes
55
56
57def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True):
58  """Saves a model to a HDF5 file.
59
60  The saved model contains:
61      - the model's configuration (topology)
62      - the model's weights
63      - the model's optimizer's state (if any)
64
65  Thus the saved model can be reinstantiated in
66  the exact same state, without any of the code
67  used for model definition or training.
68
69  Args:
70      model: Keras model instance to be saved.
71      filepath: One of the following:
72          - String, path where to save the model
73          - `h5py.File` object where to save the model
74      overwrite: Whether we should overwrite any existing
75          model at the target location, or instead
76          ask the user with a manual prompt.
77      include_optimizer: If True, save optimizer's state together.
78
79  Raises:
80      ImportError: if h5py is not available.
81  """
82
83  if h5py is None:
84    raise ImportError('`save_model` requires h5py.')
85
86  # TODO(psv) Add warning when we save models that contain non-serializable
87  # entities like metrics added using `add_metric` and losses added using
88  # `add_loss.`
89  if len(model.weights) != len(model._undeduplicated_weights):
90    logging.warning('Found duplicated `Variable`s in Model\'s `weights`. '
91                    'This is usually caused by `Variable`s being shared by '
92                    'Layers in the Model. These `Variable`s will be treated '
93                    'as separate `Variable`s when the Model is restored. To '
94                    'avoid this, please save with `save_format="tf"`.')
95
96  if not isinstance(filepath, h5py.File):
97    # If file exists and should not be overwritten.
98    if not overwrite and os.path.isfile(filepath):
99      proceed = ask_to_proceed_with_overwrite(filepath)
100      if not proceed:
101        return
102
103    # Try creating dir if not exist
104    dirpath = os.path.dirname(filepath)
105    if not os.path.exists(dirpath):
106      gfile.MakeDirs(dirpath)
107
108    f = h5py.File(filepath, mode='w')
109    opened_new_file = True
110  else:
111    f = filepath
112    opened_new_file = False
113
114  try:
115    model_metadata = saving_utils.model_metadata(model, include_optimizer)
116    for k, v in model_metadata.items():
117      if isinstance(v, (dict, list, tuple)):
118        f.attrs[k] = json.dumps(
119            v, default=json_utils.get_json_type).encode('utf8')
120      else:
121        f.attrs[k] = v
122
123    model_weights_group = f.create_group('model_weights')
124    model_layers = model.layers
125    save_weights_to_hdf5_group(model_weights_group, model_layers)
126
127    # TODO(b/128683857): Add integration tests between tf.keras and external
128    # Keras, to avoid breaking TF.js users.
129    if (include_optimizer and model.optimizer and
130        not isinstance(model.optimizer, optimizer_v1.TFOptimizer)):
131      save_optimizer_weights_to_hdf5_group(f, model.optimizer)
132
133    f.flush()
134  finally:
135    if opened_new_file:
136      f.close()
137
138
139def load_model_from_hdf5(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
140  """Loads a model saved via `save_model_to_hdf5`.
141
142  Args:
143      filepath: One of the following:
144          - String, path to the saved model
145          - `h5py.File` object from which to load the model
146      custom_objects: Optional dictionary mapping names
147          (strings) to custom classes or functions to be
148          considered during deserialization.
149      compile: Boolean, whether to compile the model
150          after loading.
151
152  Returns:
153      A Keras model instance. If an optimizer was found
154      as part of the saved model, the model is already
155      compiled. Otherwise, the model is uncompiled and
156      a warning will be displayed. When `compile` is set
157      to False, the compilation is omitted without any
158      warning.
159
160  Raises:
161      ImportError: if h5py is not available.
162      ValueError: In case of an invalid savefile.
163  """
164  if h5py is None:
165    raise ImportError('`load_model` requires h5py.')
166
167  if not custom_objects:
168    custom_objects = {}
169
170  opened_new_file = not isinstance(filepath, h5py.File)
171  if opened_new_file:
172    f = h5py.File(filepath, mode='r')
173  else:
174    f = filepath
175
176  model = None
177  try:
178    # instantiate model
179    model_config = f.attrs.get('model_config')
180    if model_config is None:
181      raise ValueError('No model found in config file.')
182    if hasattr(model_config, 'decode'):
183      model_config = model_config.decode('utf-8')
184    model_config = json_utils.decode(model_config)
185    model = model_config_lib.model_from_config(model_config,
186                                               custom_objects=custom_objects)
187
188    # set weights
189    load_weights_from_hdf5_group(f['model_weights'], model.layers)
190
191    if compile:
192      # instantiate optimizer
193      training_config = f.attrs.get('training_config')
194      if hasattr(training_config, 'decode'):
195        training_config = training_config.decode('utf-8')
196      if training_config is None:
197        logging.warning('No training configuration found in the save file, so '
198                        'the model was *not* compiled. Compile it manually.')
199        return model
200      training_config = json_utils.decode(training_config)
201
202      # Compile model.
203      model.compile(**saving_utils.compile_args_from_training_config(
204          training_config, custom_objects))
205      saving_utils.try_build_compiled_arguments(model)
206
207      # Set optimizer weights.
208      if 'optimizer_weights' in f:
209        try:
210          model.optimizer._create_all_weights(model.trainable_variables)
211        except (NotImplementedError, AttributeError):
212          logging.warning(
213              'Error when creating the weights of optimizer {}, making it '
214              'impossible to restore the saved optimizer state. As a result, '
215              'your model is starting with a freshly initialized optimizer.')
216
217        optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
218        try:
219          model.optimizer.set_weights(optimizer_weight_values)
220        except ValueError:
221          logging.warning('Error in loading the saved optimizer '
222                          'state. As a result, your model is '
223                          'starting with a freshly initialized '
224                          'optimizer.')
225  finally:
226    if opened_new_file:
227      f.close()
228  return model
229
230
231def preprocess_weights_for_loading(layer,
232                                   weights,
233                                   original_keras_version=None,
234                                   original_backend=None):
235  """Preprocess layer weights between different Keras formats.
236
237  Converts layers weights from Keras 1 format to Keras 2 and also weights of
238  CuDNN layers in Keras 2.
239
240  Args:
241      layer: Layer instance.
242      weights: List of weights values (Numpy arrays).
243      original_keras_version: Keras version for the weights, as a string.
244      original_backend: Keras backend the weights were trained with,
245          as a string.
246
247  Returns:
248      A list of weights values (Numpy arrays).
249  """
250  def convert_nested_bidirectional(weights):
251    """Converts layers nested in `Bidirectional` wrapper.
252
253    This function uses `preprocess_weights_for_loading()` for converting
254    layers.
255
256    Args:
257        weights: List of weights values (Numpy arrays).
258
259    Returns:
260        A list of weights values (Numpy arrays).
261    """
262    num_weights_per_layer = len(weights) // 2
263    forward_weights = preprocess_weights_for_loading(
264        layer.forward_layer, weights[:num_weights_per_layer],
265        original_keras_version, original_backend)
266    backward_weights = preprocess_weights_for_loading(
267        layer.backward_layer, weights[num_weights_per_layer:],
268        original_keras_version, original_backend)
269    return forward_weights + backward_weights
270
271  def convert_nested_time_distributed(weights):
272    """Converts layers nested in `TimeDistributed` wrapper.
273
274    This function uses `preprocess_weights_for_loading()` for converting nested
275    layers.
276
277    Args:
278        weights: List of weights values (Numpy arrays).
279
280    Returns:
281        A list of weights values (Numpy arrays).
282    """
283    return preprocess_weights_for_loading(
284        layer.layer, weights, original_keras_version, original_backend)
285
286  def convert_nested_model(weights):
287    """Converts layers nested in `Model` or `Sequential`.
288
289    This function uses `preprocess_weights_for_loading()` for converting nested
290    layers.
291
292    Args:
293        weights: List of weights values (Numpy arrays).
294
295    Returns:
296        A list of weights values (Numpy arrays).
297    """
298    trainable_weights = weights[:len(layer.trainable_weights)]
299    non_trainable_weights = weights[len(layer.trainable_weights):]
300
301    new_trainable_weights = []
302    new_non_trainable_weights = []
303
304    for sublayer in layer.layers:
305      num_trainable_weights = len(sublayer.trainable_weights)
306      num_non_trainable_weights = len(sublayer.non_trainable_weights)
307      if sublayer.weights:
308        preprocessed = preprocess_weights_for_loading(
309            layer=sublayer,
310            weights=(trainable_weights[:num_trainable_weights] +
311                     non_trainable_weights[:num_non_trainable_weights]),
312            original_keras_version=original_keras_version,
313            original_backend=original_backend)
314        new_trainable_weights.extend(preprocessed[:num_trainable_weights])
315        new_non_trainable_weights.extend(preprocessed[num_trainable_weights:])
316
317        trainable_weights = trainable_weights[num_trainable_weights:]
318        non_trainable_weights = non_trainable_weights[
319            num_non_trainable_weights:]
320
321    return new_trainable_weights + new_non_trainable_weights
322
323  # Convert layers nested in Bidirectional/Model/Sequential.
324  # Both transformation should be ran for both Keras 1->2 conversion
325  # and for conversion of CuDNN layers.
326  if layer.__class__.__name__ == 'Bidirectional':
327    weights = convert_nested_bidirectional(weights)
328  if layer.__class__.__name__ == 'TimeDistributed':
329    weights = convert_nested_time_distributed(weights)
330  elif layer.__class__.__name__ in ['Model', 'Sequential']:
331    weights = convert_nested_model(weights)
332
333  if original_keras_version == '1':
334    if layer.__class__.__name__ == 'TimeDistributed':
335      weights = preprocess_weights_for_loading(
336          layer.layer, weights, original_keras_version, original_backend)
337
338    if layer.__class__.__name__ == 'Conv1D':
339      shape = weights[0].shape
340      # Handle Keras 1.1 format
341      if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters:
342        # Legacy shape:
343        # (filters, input_dim, filter_length, 1)
344        assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0],
345                                                           1)
346        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
347      weights[0] = weights[0][:, 0, :, :]
348
349    if layer.__class__.__name__ == 'Conv2D':
350      if layer.data_format == 'channels_first':
351        # old: (filters, stack_size, kernel_rows, kernel_cols)
352        # new: (kernel_rows, kernel_cols, stack_size, filters)
353        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
354
355    if layer.__class__.__name__ == 'Conv2DTranspose':
356      if layer.data_format == 'channels_last':
357        # old: (kernel_rows, kernel_cols, stack_size, filters)
358        # new: (kernel_rows, kernel_cols, filters, stack_size)
359        weights[0] = np.transpose(weights[0], (0, 1, 3, 2))
360      if layer.data_format == 'channels_first':
361        # old: (filters, stack_size, kernel_rows, kernel_cols)
362        # new: (kernel_rows, kernel_cols, filters, stack_size)
363        weights[0] = np.transpose(weights[0], (2, 3, 0, 1))
364
365    if layer.__class__.__name__ == 'Conv3D':
366      if layer.data_format == 'channels_first':
367        # old: (filters, stack_size, ...)
368        # new: (..., stack_size, filters)
369        weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0))
370
371    if layer.__class__.__name__ == 'GRU':
372      if len(weights) == 9:
373        kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1)
374        recurrent_kernel = np.concatenate(
375            [weights[1], weights[4], weights[7]], axis=-1)
376        bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1)
377        weights = [kernel, recurrent_kernel, bias]
378
379    if layer.__class__.__name__ == 'LSTM':
380      if len(weights) == 12:
381        # old: i, c, f, o
382        # new: i, f, c, o
383        kernel = np.concatenate(
384            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
385        recurrent_kernel = np.concatenate(
386            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
387        bias = np.concatenate(
388            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
389        weights = [kernel, recurrent_kernel, bias]
390
391    if layer.__class__.__name__ == 'ConvLSTM2D':
392      if len(weights) == 12:
393        kernel = np.concatenate(
394            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
395        recurrent_kernel = np.concatenate(
396            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
397        bias = np.concatenate(
398            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
399        if layer.data_format == 'channels_first':
400          # old: (filters, stack_size, kernel_rows, kernel_cols)
401          # new: (kernel_rows, kernel_cols, stack_size, filters)
402          kernel = np.transpose(kernel, (2, 3, 1, 0))
403          recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
404        weights = [kernel, recurrent_kernel, bias]
405
406  conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
407  if layer.__class__.__name__ in conv_layers:
408    if K.int_shape(layer.weights[0]) != weights[0].shape:
409      weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
410      if layer.__class__.__name__ == 'ConvLSTM2D':
411        weights[1] = np.transpose(weights[1], (3, 2, 0, 1))
412
413  # convert CuDNN layers
414  return _convert_rnn_weights(layer, weights)
415
416
417def _convert_rnn_weights(layer, weights):
418  """Converts weights for RNN layers between native and CuDNN format.
419
420  Input kernels for each gate are transposed and converted between Fortran
421  and C layout, recurrent kernels are transposed. For LSTM biases are summed/
422  split in half, for GRU biases are reshaped.
423
424  Weights can be converted in both directions between `LSTM` and`CuDNNSLTM`
425  and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not
426  compatible with `CuDNNGRU`.
427
428  For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made.
429
430  Args:
431      layer: Target layer instance.
432      weights: List of source weights values (input kernels, recurrent
433          kernels, [biases]) (Numpy arrays).
434
435  Returns:
436      A list of converted weights values (Numpy arrays).
437
438  Raises:
439      ValueError: for incompatible GRU layer/weights or incompatible biases
440  """
441
442  def transform_kernels(kernels, func, n_gates):
443    """Transforms kernel for each gate separately using given function.
444
445    Args:
446        kernels: Stacked array of kernels for individual gates.
447        func: Function applied to kernel of each gate.
448        n_gates: Number of gates (4 for LSTM, 3 for GRU).
449
450    Returns:
451        Stacked array of transformed kernels.
452    """
453    return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)])
454
455  def transpose_input(from_cudnn):
456    """Makes a function that transforms input kernels from/to CuDNN format.
457
458    It keeps the shape, but changes between the layout (Fortran/C). Eg.:
459
460    ```
461    Keras                 CuDNN
462    [[0, 1, 2],  <--->  [[0, 2, 4],
463     [3, 4, 5]]          [1, 3, 5]]
464    ```
465
466    It can be passed to `transform_kernels()`.
467
468    Args:
469        from_cudnn: `True` if source weights are in CuDNN format, `False`
470            if they're in plain Keras format.
471
472    Returns:
473        Function that converts input kernel to the other format.
474    """
475    order = 'F' if from_cudnn else 'C'
476
477    def transform(kernel):
478      return kernel.T.reshape(kernel.shape, order=order)
479
480    return transform
481
482  target_class = layer.__class__.__name__
483
484  # convert the weights between CuDNNLSTM and LSTM
485  if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3:
486    # determine if we're loading a CuDNNLSTM layer
487    # from the number of bias weights:
488    # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4)
489    # if there's no bias weight in the file, skip this conversion
490    units = weights[1].shape[0]
491    bias_shape = weights[2].shape
492    n_gates = 4
493
494    if bias_shape == (2 * units * n_gates,):
495      source = 'CuDNNLSTM'
496    elif bias_shape == (units * n_gates,):
497      source = 'LSTM'
498    else:
499      raise ValueError('Invalid bias shape: ' + str(bias_shape))
500
501    def convert_lstm_weights(weights, from_cudnn=True):
502      """Converts the weights between CuDNNLSTM and LSTM.
503
504      Args:
505        weights: Original weights.
506        from_cudnn: Indicates whether original weights are from CuDNN layer.
507
508      Returns:
509        Updated weights compatible with LSTM.
510      """
511
512      # Transpose (and reshape) input and recurrent kernels
513      kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
514                                  n_gates)
515      recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
516      if from_cudnn:
517        # merge input and recurrent biases into a single set
518        biases = np.sum(np.split(weights[2], 2, axis=0), axis=0)
519      else:
520        # Split single set of biases evenly to two sets. The way of
521        # splitting doesn't matter as long as the two sets sum is kept.
522        biases = np.tile(0.5 * weights[2], 2)
523      return [kernels, recurrent_kernels, biases]
524
525    if source != target_class:
526      weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM')
527
528  # convert the weights between CuDNNGRU and GRU(reset_after=True)
529  if target_class in ['GRU', 'CuDNNGRU'] and len(weights) == 3:
530    # We can determine the source of the weights from the shape of the bias.
531    # If there is no bias we skip the conversion since
532    # CuDNNGRU always has biases.
533
534    units = weights[1].shape[0]
535    bias_shape = weights[2].shape
536    n_gates = 3
537
538    def convert_gru_weights(weights, from_cudnn=True):
539      """Converts the weights between CuDNNGRU and GRU.
540
541      Args:
542        weights: Original weights.
543        from_cudnn: Indicates whether original weights are from CuDNN layer.
544
545      Returns:
546        Updated weights compatible with GRU.
547      """
548
549      kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
550                                  n_gates)
551      recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
552      biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1)
553      return [kernels, recurrent_kernels, biases]
554
555    if bias_shape == (2 * units * n_gates,):
556      source = 'CuDNNGRU'
557    elif bias_shape == (2, units * n_gates):
558      source = 'GRU(reset_after=True)'
559    elif bias_shape == (units * n_gates,):
560      source = 'GRU(reset_after=False)'
561    else:
562      raise ValueError('Invalid bias shape: ' + str(bias_shape))
563
564    if target_class == 'CuDNNGRU':
565      target = 'CuDNNGRU'
566    elif layer.reset_after:
567      target = 'GRU(reset_after=True)'
568    else:
569      target = 'GRU(reset_after=False)'
570
571    # only convert between different types
572    if source != target:
573      types = (source, target)
574      if 'GRU(reset_after=False)' in types:
575        raise ValueError('%s is not compatible with %s' % types)
576      if source == 'CuDNNGRU':
577        weights = convert_gru_weights(weights, from_cudnn=True)
578      elif source == 'GRU(reset_after=True)':
579        weights = convert_gru_weights(weights, from_cudnn=False)
580
581  return weights
582
583
584def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
585  """Saves optimizer weights of a optimizer to a HDF5 group.
586
587  Args:
588      hdf5_group: HDF5 group.
589      optimizer: optimizer instance.
590  """
591
592  symbolic_weights = getattr(optimizer, 'weights')
593  if symbolic_weights:
594    weights_group = hdf5_group.create_group('optimizer_weights')
595    weight_names = [str(w.name).encode('utf8') for w in symbolic_weights]
596    save_attributes_to_hdf5_group(weights_group, 'weight_names', weight_names)
597    weight_values = K.batch_get_value(symbolic_weights)
598    for name, val in zip(weight_names, weight_values):
599      param_dset = weights_group.create_dataset(
600          name, val.shape, dtype=val.dtype)
601      if not val.shape:
602        # scalar
603        param_dset[()] = val
604      else:
605        param_dset[:] = val
606
607
608def load_optimizer_weights_from_hdf5_group(hdf5_group):
609  """Load optimizer weights from a HDF5 group.
610
611  Args:
612      hdf5_group: A pointer to a HDF5 group.
613
614  Returns:
615      data: List of optimizer weight names.
616  """
617  weights_group = hdf5_group['optimizer_weights']
618  optimizer_weight_names = load_attributes_from_hdf5_group(
619      weights_group, 'weight_names')
620  return [weights_group[weight_name] for weight_name in optimizer_weight_names]
621
622
623def save_weights_to_hdf5_group(f, layers):
624  """Saves the weights of a list of layers to a HDF5 group.
625
626  Args:
627      f: HDF5 group.
628      layers: List of layer instances.
629  """
630  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top
631
632  save_attributes_to_hdf5_group(
633      f, 'layer_names', [layer.name.encode('utf8') for layer in layers])
634  f.attrs['backend'] = K.backend().encode('utf8')
635  f.attrs['keras_version'] = str(keras_version).encode('utf8')
636
637  # Sort model layers by layer name to ensure that group names are strictly
638  # growing to avoid prefix issues.
639  for layer in sorted(layers, key=lambda x: x.name):
640    g = f.create_group(layer.name)
641    weights = _legacy_weights(layer)
642    weight_values = K.batch_get_value(weights)
643    weight_names = [w.name.encode('utf8') for w in weights]
644    save_attributes_to_hdf5_group(g, 'weight_names', weight_names)
645    for name, val in zip(weight_names, weight_values):
646      param_dset = g.create_dataset(name, val.shape, dtype=val.dtype)
647      if not val.shape:
648        # scalar
649        param_dset[()] = val
650      else:
651        param_dset[:] = val
652
653
654def load_weights_from_hdf5_group(f, layers):
655  """Implements topological (order-based) weight loading.
656
657  Args:
658      f: A pointer to a HDF5 group.
659      layers: a list of target layers.
660
661  Raises:
662      ValueError: in case of mismatch between provided layers
663          and weights file.
664  """
665  if 'keras_version' in f.attrs:
666    original_keras_version = f.attrs['keras_version']
667    if hasattr(original_keras_version, 'decode'):
668      original_keras_version = original_keras_version.decode('utf8')
669  else:
670    original_keras_version = '1'
671  if 'backend' in f.attrs:
672    original_backend = f.attrs['backend']
673    if hasattr(original_backend, 'decode'):
674      original_backend = original_backend.decode('utf8')
675  else:
676    original_backend = None
677
678  filtered_layers = []
679  for layer in layers:
680    weights = _legacy_weights(layer)
681    if weights:
682      filtered_layers.append(layer)
683
684  layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
685  filtered_layer_names = []
686  for name in layer_names:
687    g = f[name]
688    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
689    if weight_names:
690      filtered_layer_names.append(name)
691  layer_names = filtered_layer_names
692  if len(layer_names) != len(filtered_layers):
693    raise ValueError('You are trying to load a weight file '
694                     'containing ' + str(len(layer_names)) +
695                     ' layers into a model with ' + str(len(filtered_layers)) +
696                     ' layers.')
697
698  # We batch weight value assignments in a single backend call
699  # which provides a speedup in TensorFlow.
700  weight_value_tuples = []
701  for k, name in enumerate(layer_names):
702    g = f[name]
703    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
704    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
705    layer = filtered_layers[k]
706    symbolic_weights = _legacy_weights(layer)
707    weight_values = preprocess_weights_for_loading(
708        layer, weight_values, original_keras_version, original_backend)
709    if len(weight_values) != len(symbolic_weights):
710      raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
711                       '" in the current model) was found to '
712                       'correspond to layer ' + name + ' in the save file. '
713                       'However the new layer ' + layer.name + ' expects ' +
714                       str(len(symbolic_weights)) +
715                       ' weights, but the saved weights have ' +
716                       str(len(weight_values)) + ' elements.')
717    weight_value_tuples += zip(symbolic_weights, weight_values)
718  K.batch_set_value(weight_value_tuples)
719
720
721def load_weights_from_hdf5_group_by_name(
722    f, layers, skip_mismatch=False):
723  """Implements name-based weight loading.
724
725  (instead of topological weight loading).
726
727  Layers that have no matching name are skipped.
728
729  Args:
730      f: A pointer to a HDF5 group.
731      layers: a list of target layers.
732      skip_mismatch: Boolean, whether to skip loading of layers
733          where there is a mismatch in the number of weights,
734          or a mismatch in the shape of the weights.
735
736  Raises:
737      ValueError: in case of mismatch between provided layers
738          and weights file and skip_match=False.
739  """
740  if 'keras_version' in f.attrs:
741    original_keras_version = f.attrs['keras_version']
742    if hasattr(original_keras_version, 'decode'):
743      original_keras_version = original_keras_version.decode('utf8')
744  else:
745    original_keras_version = '1'
746  if 'backend' in f.attrs:
747    original_backend = f.attrs['backend']
748    if hasattr(original_backend, 'decode'):
749      original_backend = original_backend.decode('utf8')
750  else:
751    original_backend = None
752
753  # New file format.
754  layer_names = load_attributes_from_hdf5_group(f, 'layer_names')
755
756  # Reverse index of layer name to list of layers with name.
757  index = {}
758  for layer in layers:
759    if layer.name:
760      index.setdefault(layer.name, []).append(layer)
761
762  # We batch weight value assignments in a single backend call
763  # which provides a speedup in TensorFlow.
764  weight_value_tuples = []
765  for k, name in enumerate(layer_names):
766    g = f[name]
767    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
768    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
769
770    for layer in index.get(name, []):
771      symbolic_weights = _legacy_weights(layer)
772      weight_values = preprocess_weights_for_loading(
773          layer, weight_values, original_keras_version, original_backend)
774      if len(weight_values) != len(symbolic_weights):
775        if skip_mismatch:
776          logging.warning('Skipping loading of weights for '
777                          'layer {}'.format(layer.name) + ' due to mismatch '
778                          'in number of weights ({} vs {}).'.format(
779                              len(symbolic_weights), len(weight_values)))
780          continue
781        raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
782                         '") expects ' + str(len(symbolic_weights)) +
783                         ' weight(s), but the saved weights' + ' have ' +
784                         str(len(weight_values)) + ' element(s).')
785      # Set values.
786      for i in range(len(weight_values)):
787        if K.int_shape(symbolic_weights[i]) != weight_values[i].shape:
788          if skip_mismatch:
789            logging.warning('Skipping loading of weights for '
790                            'layer {}'.format(layer.name) + ' due to '
791                            'mismatch in shape ({} vs {}).'.format(
792                                symbolic_weights[i].shape,
793                                weight_values[i].shape))
794            continue
795          raise ValueError('Layer #' + str(k) +' (named "' + layer.name +
796                           '"), weight ' + str(symbolic_weights[i]) +
797                           ' has shape {}'.format(K.int_shape(
798                               symbolic_weights[i])) +
799                           ', but the saved weight has shape ' +
800                           str(weight_values[i].shape) + '.')
801
802        else:
803          weight_value_tuples.append((symbolic_weights[i], weight_values[i]))
804  K.batch_set_value(weight_value_tuples)
805
806
807def save_attributes_to_hdf5_group(group, name, data):
808  """Saves attributes (data) of the specified name into the HDF5 group.
809
810  This method deals with an inherent problem of HDF5 file which is not
811  able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes.
812
813  Args:
814      group: A pointer to a HDF5 group.
815      name: A name of the attributes to save.
816      data: Attributes data to store.
817
818  Raises:
819    RuntimeError: If any single attribute is too large to be saved.
820  """
821  # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
822  # because in that case even chunking the array would not make the saving
823  # possible.
824  bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]
825
826  # Expecting this to never be true.
827  if bad_attributes:
828    raise RuntimeError('The following attributes cannot be saved to HDF5 '
829                       'file because they are larger than %d bytes: %s' %
830                       (HDF5_OBJECT_HEADER_LIMIT, ', '.join(bad_attributes)))
831
832  data_npy = np.asarray(data)
833
834  num_chunks = 1
835  chunked_data = np.array_split(data_npy, num_chunks)
836
837  # This will never loop forever thanks to the test above.
838  while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
839    num_chunks += 1
840    chunked_data = np.array_split(data_npy, num_chunks)
841
842  if num_chunks > 1:
843    for chunk_id, chunk_data in enumerate(chunked_data):
844      group.attrs['%s%d' % (name, chunk_id)] = chunk_data
845  else:
846    group.attrs[name] = data
847
848
849def load_attributes_from_hdf5_group(group, name):
850  """Loads attributes of the specified name from the HDF5 group.
851
852  This method deals with an inherent problem
853  of HDF5 file which is not able to store
854  data larger than HDF5_OBJECT_HEADER_LIMIT bytes.
855
856  Args:
857      group: A pointer to a HDF5 group.
858      name: A name of the attributes to load.
859
860  Returns:
861      data: Attributes data.
862  """
863  if name in group.attrs:
864    data = [
865        n.decode('utf8') if hasattr(n, 'decode') else n
866        for n in group.attrs[name]
867    ]
868  else:
869    data = []
870    chunk_id = 0
871    while '%s%d' % (name, chunk_id) in group.attrs:
872      data.extend([
873          n.decode('utf8') if hasattr(n, 'decode') else n
874          for n in group.attrs['%s%d' % (name, chunk_id)]
875      ])
876      chunk_id += 1
877  return data
878
879
880def _legacy_weights(layer):
881  """DO NOT USE.
882
883  For legacy reason, the layer.weights was in the order of
884  [self.trainable_weights + self.non_trainable_weights], and this order was
885  used for preserving the weights in h5 format. The new order of layer.weights
886  are the same as layer.get_weights() which is more intuitive for user. To
887  keep supporting the existing saved h5 file, this method should be used to
888  save/load weights. In future version, we will delete this method and
889  introduce a breaking change for h5 and stay with the new order for weights.
890
891  Args:
892    layer: a `tf.keras.Model` or `tf.keras.layers.Layer` instance.
893
894  Returns:
895    A list of variables with the order of trainable_weights, followed by
896      non_trainable_weights.
897  """
898  weights = layer.trainable_weights + layer.non_trainable_weights
899  if any(not isinstance(w, variables_module.Variable) for w in weights):
900    raise NotImplementedError(
901        'Save or restore weights that is not an instance of `tf.Variable` is '
902        'not supported in h5, use `save_format=\'tf\'` instead. Got a model '
903        'or layer {} with weights {}'.format(layer.__class__.__name__, weights))
904  return weights
905