1# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ==============================================================================
15# pylint: disable=protected-access
16"""Functions for saving and loading a Keras Model from HDF5 format.
17"""
18from __future__ import absolute_import
19from __future__ import division
20from __future__ import print_function
21
22import json
23import os
24
25import numpy as np
26from six.moves import zip  # pylint: disable=redefined-builtin
27
28from tensorflow.python.keras import backend as K
29from tensorflow.python.keras import optimizers
30from tensorflow.python.keras.saving import model_config as model_config_lib
31from tensorflow.python.keras.utils import conv_utils
32from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite
33from tensorflow.python.platform import tf_logging as logging
34from tensorflow.python.util import serialization
35from tensorflow.python.util.tf_export import keras_export
36
37# pylint: disable=g-import-not-at-top
38try:
39  import h5py
40  HDF5_OBJECT_HEADER_LIMIT = 64512
41except ImportError:
42  h5py = None
43# pylint: enable=g-import-not-at-top
44
45
@keras_export('keras.models.save_model')
def save_model(model, filepath, overwrite=True, include_optimizer=True):
  """Saves a model to a HDF5 file.

  The written file holds:
      - the model's configuration (topology)
      - the model's weights
      - the model's optimizer's state (if any)

  so the model can later be re-instantiated in exactly the same state,
  without needing the code that originally defined or trained it.

  Arguments:
      model: Keras model instance to be saved.
      filepath: One of the following:
          - String, path where to save the model
          - `h5py.File` object where to save the model
      overwrite: Whether we should overwrite any existing
          model at the target location, or instead
          ask the user with a manual prompt.
      include_optimizer: If True, save optimizer's state together.

  Raises:
      ImportError: if h5py is not available.
  """
  if h5py is None:
    raise ImportError('`save_model` requires h5py.')

  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

  # TODO(psv) Add warning when we save models that contain non-serializable
  # entities like metrics added using `add_metric` and losses added using
  # `add_loss.`

  # Only files opened here are closed here; a caller-supplied `h5py.File`
  # stays open for the caller to manage.
  opened_new_file = not isinstance(filepath, h5py.File)
  if opened_new_file:
    # With `overwrite=False`, ask before clobbering an existing file.
    if not overwrite and os.path.isfile(filepath):
      if not ask_to_proceed_with_overwrite(filepath):
        return
    f = h5py.File(filepath, mode='w')
  else:
    f = filepath

  try:
    f.attrs['keras_version'] = str(keras_version).encode('utf8')
    f.attrs['backend'] = K.backend().encode('utf8')
    model_metadata = {
        'class_name': model.__class__.__name__,
        'config': model.get_config()
    }
    f.attrs['model_config'] = json.dumps(
        model_metadata, default=serialization.get_json_type).encode('utf8')

    save_weights_to_hdf5_group(f.create_group('model_weights'), model.layers)

    if include_optimizer and model.optimizer:
      optimizer = model.optimizer
      if isinstance(optimizer, optimizers.TFOptimizer):
        # Native TF optimizers expose no config/state to serialize.
        logging.warning(
            'TensorFlow optimizers do not '
            'make it possible to access '
            'optimizer attributes or optimizer state '
            'after instantiation. '
            'As a result, we cannot save the optimizer '
            'as part of the model save file. '
            'You will have to compile your model again after loading it. '
            'Prefer using a Keras optimizer instead '
            '(see keras.io/optimizers).')
      else:
        training_config = {
            'optimizer_config': {
                'class_name': optimizer.__class__.__name__,
                'config': optimizer.get_config()
            },
            'loss': model.loss,
            'metrics': model._compile_metrics,
            'weighted_metrics': model._compile_weighted_metrics,
            'sample_weight_mode': model.sample_weight_mode,
            'loss_weights': model.loss_weights,
        }
        f.attrs['training_config'] = json.dumps(
            training_config,
            default=serialization.get_json_type).encode('utf8')

        # Optimizer slot values go into the 'optimizer_weights' group.
        save_optimizer_weights_to_hdf5_group(f, optimizer)
    f.flush()
  finally:
    if opened_new_file:
      f.close()
142
143
@keras_export('keras.models.load_model')
def load_model(filepath, custom_objects=None, compile=True):  # pylint: disable=redefined-builtin
  """Loads a model saved via `save_model`.

  Arguments:
      filepath: One of the following:
          - String, path to the saved model
          - `h5py.File` object from which to load the model
      custom_objects: Optional dictionary mapping names
          (strings) to custom classes or functions to be
          considered during deserialization.
      compile: Boolean, whether to compile the model
          after loading.

  Returns:
      A Keras model instance. If an optimizer was found
      as part of the saved model, the model is already
      compiled. Otherwise, the model is uncompiled and
      a warning will be displayed. When `compile` is set
      to False, the compilation is omitted without any
      warning.

  Raises:
      ImportError: if h5py is not available.
      ValueError: In case of an invalid savefile.
  """
  if h5py is None:
    raise ImportError('`load_model` requires h5py.')

  if not custom_objects:
    custom_objects = {}

  def convert_custom_objects(obj):
    """Handles custom object lookup.

    Arguments:
        obj: object, dict, or list.

    Returns:
        The same structure, where occurrences
            of a custom object name have been replaced
            with the custom object.
    """
    if isinstance(obj, list):
      return [convert_custom_objects(value) for value in obj]
    if isinstance(obj, dict):
      return {key: convert_custom_objects(value) for key, value in obj.items()}
    if obj in custom_objects:
      return custom_objects[obj]
    return obj

  def decode_attr(value):
    """Returns an HDF5 string attribute as text.

    `save_model` stores these attributes as utf8 `bytes`; depending on the
    h5py version they may be read back as `bytes` or already as `str`.
    """
    if hasattr(value, 'decode'):
      return value.decode('utf-8')
    return value

  # Only files opened here are closed here.
  opened_new_file = not isinstance(filepath, h5py.File)
  if opened_new_file:
    f = h5py.File(filepath, mode='r')
  else:
    f = filepath

  model = None
  try:
    # instantiate model
    model_config = f.attrs.get('model_config')
    if model_config is None:
      raise ValueError('No model found in config file.')
    model_config = json.loads(decode_attr(model_config))
    model = model_config_lib.model_from_config(model_config,
                                               custom_objects=custom_objects)

    # set weights
    load_weights_from_hdf5_group(f['model_weights'], model.layers)

    if compile:
      # instantiate optimizer
      training_config = f.attrs.get('training_config')
      if training_config is None:
        logging.warning('No training configuration found in save file: '
                        'the model was *not* compiled. Compile it manually.')
        return model
      training_config = json.loads(decode_attr(training_config))
      optimizer_config = training_config['optimizer_config']
      optimizer = optimizers.deserialize(
          optimizer_config, custom_objects=custom_objects)

      # Recover loss functions and metrics.
      loss = convert_custom_objects(training_config['loss'])
      metrics = convert_custom_objects(training_config['metrics'])
      weighted_metrics = convert_custom_objects(
          training_config.get('weighted_metrics', None))
      sample_weight_mode = training_config['sample_weight_mode']
      loss_weights = training_config['loss_weights']

      # Compile model.
      model.compile(
          optimizer=optimizer,
          loss=loss,
          metrics=metrics,
          weighted_metrics=weighted_metrics,
          loss_weights=loss_weights,
          sample_weight_mode=sample_weight_mode)

      # Set optimizer weights.
      if 'optimizer_weights' in f:
        # Build train function (to get weight updates).
        # Models that aren't graph networks must wait until they are called
        # with data to _make_train_function() and so can't load optimizer
        # weights.
        if model._is_graph_network:  # pylint: disable=protected-access
          model._make_train_function()
          optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f)
          try:
            model.optimizer.set_weights(optimizer_weight_values)
          except ValueError:
            logging.warning('Error in loading the saved optimizer '
                            'state. As a result, your model is '
                            'starting with a freshly initialized '
                            'optimizer.')
        else:
          logging.warning('Sequential models without an `input_shape` '
                          'passed to the first layer cannot reload their '
                          'optimizer state. As a result, your model is '
                          'starting with a freshly initialized optimizer.')

  finally:
    if opened_new_file:
      f.close()
  return model
275
276
def preprocess_weights_for_loading(layer,
                                   weights,
                                   original_keras_version=None,
                                   original_backend=None):
  """Preprocess layer weights between different Keras formats.

  Converts layers weights from Keras 1 format to Keras 2 and also weights of
  CuDNN layers in Keras 2.

  Wrapper layers (`Bidirectional`, `TimeDistributed`) and nested models
  (`Model`, `Sequential`) are handled by recursing into their inner layers
  with the same `original_keras_version` / `original_backend`. Each inner
  layer therefore receives every applicable conversion exactly once.

  Note that entries of `weights` may be replaced in place.

  Arguments:
      layer: Layer instance.
      weights: List of weights values (Numpy arrays).
      original_keras_version: Keras version for the weights, as a string.
      original_backend: Keras backend the weights were trained with,
          as a string.

  Returns:
      A list of weights values (Numpy arrays).
  """
  def convert_nested_bidirectional(weights):
    """Converts layers nested in `Bidirectional` wrapper.

    This function uses `preprocess_weights_for_loading()` for converting
    layers.

    Arguments:
        weights: List of weights values (Numpy arrays).

    Returns:
        A list of weights values (Numpy arrays).
    """
    # The first half belongs to the forward layer, the second to the backward.
    num_weights_per_layer = len(weights) // 2
    forward_weights = preprocess_weights_for_loading(
        layer.forward_layer, weights[:num_weights_per_layer],
        original_keras_version, original_backend)
    backward_weights = preprocess_weights_for_loading(
        layer.backward_layer, weights[num_weights_per_layer:],
        original_keras_version, original_backend)
    return forward_weights + backward_weights

  def convert_nested_time_distributed(weights):
    """Converts layers nested in `TimeDistributed` wrapper.

    This function uses `preprocess_weights_for_loading()` for converting nested
    layers.

    Arguments:
        weights: List of weights values (Numpy arrays).

    Returns:
        A list of weights values (Numpy arrays).
    """
    return preprocess_weights_for_loading(
        layer.layer, weights, original_keras_version, original_backend)

  def convert_nested_model(weights):
    """Converts layers nested in `Model` or `Sequential`.

    This function uses `preprocess_weights_for_loading()` for converting nested
    layers.

    Arguments:
        weights: List of weights values (Numpy arrays).

    Returns:
        A list of weights values (Numpy arrays).
    """
    new_weights = []
    # trainable weights
    for sublayer in layer.layers:
      num_weights = len(sublayer.trainable_weights)
      if num_weights > 0:
        new_weights.extend(preprocess_weights_for_loading(
            layer=sublayer,
            weights=weights[:num_weights],
            original_keras_version=original_keras_version,
            original_backend=original_backend))
        weights = weights[num_weights:]

    # non-trainable weights
    for sublayer in layer.layers:
      num_weights = len([l for l in sublayer.weights
                         if l not in sublayer.trainable_weights])
      if num_weights > 0:
        new_weights.extend(preprocess_weights_for_loading(
            layer=sublayer,
            weights=weights[:num_weights],
            original_keras_version=original_keras_version,
            original_backend=original_backend))
        weights = weights[num_weights:]
    return new_weights

  class_name = layer.__class__.__name__

  # Convert layers nested in Bidirectional/TimeDistributed/Model/Sequential.
  # Both transformation should be ran for both Keras 1->2 conversion
  # and for conversion of CuDNN layers.
  if class_name == 'Bidirectional':
    weights = convert_nested_bidirectional(weights)
  elif class_name == 'TimeDistributed':
    weights = convert_nested_time_distributed(weights)
  elif class_name in ['Model', 'Sequential']:
    weights = convert_nested_model(weights)

  if original_keras_version == '1':
    # `TimeDistributed` needs no extra handling here. The recursive call above
    # already applied the Keras 1 conversions to the wrapped layer. Applying
    # them again would e.g. re-transpose or re-flip convolution kernels.

    if class_name == 'Conv1D':
      shape = weights[0].shape
      # Handle Keras 1.1 format
      if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters:
        # Legacy shape:
        # (filters, input_dim, filter_length, 1)
        assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0],
                                                           1)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))
      weights[0] = weights[0][:, 0, :, :]

    if class_name == 'Conv2D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 1, 0))

    if class_name == 'Conv2DTranspose':
      if layer.data_format == 'channels_last':
        # old: (kernel_rows, kernel_cols, stack_size, filters)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (0, 1, 3, 2))
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, kernel_rows, kernel_cols)
        # new: (kernel_rows, kernel_cols, filters, stack_size)
        weights[0] = np.transpose(weights[0], (2, 3, 0, 1))

    if class_name == 'Conv3D':
      if layer.data_format == 'channels_first':
        # old: (filters, stack_size, ...)
        # new: (..., stack_size, filters)
        weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0))

    if class_name == 'GRU':
      if len(weights) == 9:
        # Keras 1 stored one (kernel, recurrent_kernel, bias) triple per gate;
        # Keras 2 concatenates the gates along the last axis.
        kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[4], weights[7]], axis=-1)
        bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if class_name == 'LSTM':
      if len(weights) == 12:
        # old: i, c, f, o
        # new: i, f, c, o
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        weights = [kernel, recurrent_kernel, bias]

    if class_name == 'ConvLSTM2D':
      if len(weights) == 12:
        kernel = np.concatenate(
            [weights[0], weights[6], weights[3], weights[9]], axis=-1)
        recurrent_kernel = np.concatenate(
            [weights[1], weights[7], weights[4], weights[10]], axis=-1)
        bias = np.concatenate(
            [weights[2], weights[8], weights[5], weights[11]], axis=-1)
        if layer.data_format == 'channels_first':
          # old: (filters, stack_size, kernel_rows, kernel_cols)
          # new: (kernel_rows, kernel_cols, stack_size, filters)
          kernel = np.transpose(kernel, (2, 3, 1, 0))
          recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0))
        weights = [kernel, recurrent_kernel, bias]

  conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D']
  if class_name in conv_layers:
    if original_backend == 'theano':
      weights[0] = conv_utils.convert_kernel(weights[0])
      if class_name == 'ConvLSTM2D':
        weights[1] = conv_utils.convert_kernel(weights[1])
    # NOTE(review): this fallback permutation assumes a 4-D kernel whose
    # layout differs from the target layer's; it is not valid for other ranks.
    if K.int_shape(layer.weights[0]) != weights[0].shape:
      weights[0] = np.transpose(weights[0], (3, 2, 0, 1))
      if class_name == 'ConvLSTM2D':
        weights[1] = np.transpose(weights[1], (3, 2, 0, 1))

  # convert CuDNN layers
  return _convert_rnn_weights(layer, weights)
465
466
467def _convert_rnn_weights(layer, weights):
468  """Converts weights for RNN layers between native and CuDNN format.
469
470  Input kernels for each gate are transposed and converted between Fortran
471  and C layout, recurrent kernels are transposed. For LSTM biases are summed/
472  split in half, for GRU biases are reshaped.
473
474  Weights can be converted in both directions between `LSTM` and`CuDNNSLTM`
475  and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not
476  compatible with `CuDNNGRU`.
477
478  For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made.
479
480  Arguments:
481      layer: Target layer instance.
482      weights: List of source weights values (input kernels, recurrent
483          kernels, [biases]) (Numpy arrays).
484
485  Returns:
486      A list of converted weights values (Numpy arrays).
487
488  Raises:
489      ValueError: for incompatible GRU layer/weights or incompatible biases
490  """
491
492  def transform_kernels(kernels, func, n_gates):
493    """Transforms kernel for each gate separately using given function.
494
495    Arguments:
496        kernels: Stacked array of kernels for individual gates.
497        func: Function applied to kernel of each gate.
498        n_gates: Number of gates (4 for LSTM, 3 for GRU).
499
500    Returns:
501        Stacked array of transformed kernels.
502    """
503    return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)])
504
505  def transpose_input(from_cudnn):
506    """Makes a function that transforms input kernels from/to CuDNN format.
507
508    It keeps the shape, but changes between the layout (Fortran/C). Eg.:
509
510    ```
511    Keras                 CuDNN
512    [[0, 1, 2],  <--->  [[0, 2, 4],
513     [3, 4, 5]]          [1, 3, 5]]
514    ```
515
516    It can be passed to `transform_kernels()`.
517
518    Arguments:
519        from_cudnn: `True` if source weights are in CuDNN format, `False`
520            if they're in plain Keras format.
521
522    Returns:
523        Function that converts input kernel to the other format.
524    """
525    order = 'F' if from_cudnn else 'C'
526
527    def transform(kernel):
528      return kernel.T.reshape(kernel.shape, order=order)
529
530    return transform
531
532  target_class = layer.__class__.__name__
533
534  # convert the weights between CuDNNLSTM and LSTM
535  if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3:
536    # determine if we're loading a CuDNNLSTM layer
537    # from the number of bias weights:
538    # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4)
539    # if there's no bias weight in the file, skip this conversion
540    units = weights[1].shape[0]
541    bias_shape = weights[2].shape
542    n_gates = 4
543
544    if bias_shape == (2 * units * n_gates,):
545      source = 'CuDNNLSTM'
546    elif bias_shape == (units * n_gates,):
547      source = 'LSTM'
548    else:
549      raise ValueError('Invalid bias shape: ' + str(bias_shape))
550
551    def convert_lstm_weights(weights, from_cudnn=True):
552      """Converts the weights between CuDNNLSTM and LSTM.
553
554      Arguments:
555        weights: Original weights.
556        from_cudnn: Indicates whether original weights are from CuDNN layer.
557
558      Returns:
559        Updated weights compatible with LSTM.
560      """
561
562      # Transpose (and reshape) input and recurrent kernels
563      kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
564                                  n_gates)
565      recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
566      if from_cudnn:
567        # merge input and recurrent biases into a single set
568        biases = np.sum(np.split(weights[2], 2, axis=0), axis=0)
569      else:
570        # Split single set of biases evenly to two sets. The way of
571        # splitting doesn't matter as long as the two sets sum is kept.
572        biases = np.tile(0.5 * weights[2], 2)
573      return [kernels, recurrent_kernels, biases]
574
575    if source != target_class:
576      weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM')
577
578  # convert the weights between CuDNNGRU and GRU(reset_after=True)
579  if target_class in ['GRU', 'CuDNNGRU'] and len(weights) == 3:
580    # We can determine the source of the weights from the shape of the bias.
581    # If there is no bias we skip the conversion since
582    # CuDNNGRU always has biases.
583
584    units = weights[1].shape[0]
585    bias_shape = weights[2].shape
586    n_gates = 3
587
588    def convert_gru_weights(weights, from_cudnn=True):
589      """Converts the weights between CuDNNGRU and GRU.
590
591      Arguments:
592        weights: Original weights.
593        from_cudnn: Indicates whether original weights are from CuDNN layer.
594
595      Returns:
596        Updated weights compatible with GRU.
597      """
598
599      kernels = transform_kernels(weights[0], transpose_input(from_cudnn),
600                                  n_gates)
601      recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates)
602      biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1)
603      return [kernels, recurrent_kernels, biases]
604
605    if bias_shape == (2 * units * n_gates,):
606      source = 'CuDNNGRU'
607    elif bias_shape == (2, units * n_gates):
608      source = 'GRU(reset_after=True)'
609    elif bias_shape == (units * n_gates,):
610      source = 'GRU(reset_after=False)'
611    else:
612      raise ValueError('Invalid bias shape: ' + str(bias_shape))
613
614    if target_class == 'CuDNNGRU':
615      target = 'CuDNNGRU'
616    elif layer.reset_after:
617      target = 'GRU(reset_after=True)'
618    else:
619      target = 'GRU(reset_after=False)'
620
621    # only convert between different types
622    if source != target:
623      types = (source, target)
624      if 'GRU(reset_after=False)' in types:
625        raise ValueError('%s is not compatible with %s' % types)
626      if source == 'CuDNNGRU':
627        weights = convert_gru_weights(weights, from_cudnn=True)
628      elif source == 'GRU(reset_after=True)':
629        weights = convert_gru_weights(weights, from_cudnn=False)
630
631  return weights
632
633
def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer):
  """Writes the optimizer's weights into `hdf5_group['optimizer_weights']`.

  Nothing is written (and no group is created) when the optimizer has no
  weights.

  Arguments:
      hdf5_group: HDF5 group.
      optimizer: optimizer instance.
  """
  symbolic_weights = optimizer.weights
  if not symbolic_weights:
    return

  weights_group = hdf5_group.create_group('optimizer_weights')
  encoded_names = [str(w.name).encode('utf8') for w in symbolic_weights]
  save_attributes_to_hdf5_group(weights_group, 'weight_names', encoded_names)
  for weight_name, value in zip(encoded_names,
                                K.batch_get_value(symbolic_weights)):
    dataset = weights_group.create_dataset(
        weight_name, value.shape, dtype=value.dtype)
    if value.shape:
      dataset[:] = value
    else:
      # A scalar dataset cannot be sliced; assign through `[()]`.
      dataset[()] = value
656
657
def load_optimizer_weights_from_hdf5_group(hdf5_group):
  """Load optimizer weights from a HDF5 group.

  Arguments:
      hdf5_group: A pointer to a HDF5 group containing an
          'optimizer_weights' sub-group.

  Returns:
      data: List of optimizer weight values, one entry per name listed in
          the sub-group's 'weight_names' attribute and in that order. Each
          entry is the object obtained by indexing the sub-group with that
          name (an h5py dataset for an h5py group).
  """
  weights_group = hdf5_group['optimizer_weights']
  optimizer_weight_names = load_attributes_from_hdf5_group(
      weights_group, 'weight_names')
  return [weights_group[weight_name] for weight_name in optimizer_weight_names]
671
672
def save_weights_to_hdf5_group(f, layers):
  """Writes every layer's weights into its own sub-group of `f`.

  The group also records the layer names (attribute 'layer_names'), the
  backend name and the Keras version. Each layer gets a sub-group named after
  the layer, with a 'weight_names' attribute and one dataset per weight.

  Arguments:
      f: HDF5 group.
      layers: List of layer instances.
  """
  from tensorflow.python.keras import __version__ as keras_version  # pylint: disable=g-import-not-at-top

  encoded_layer_names = [layer.name.encode('utf8') for layer in layers]
  save_attributes_to_hdf5_group(f, 'layer_names', encoded_layer_names)
  f.attrs['backend'] = K.backend().encode('utf8')
  f.attrs['keras_version'] = str(keras_version).encode('utf8')

  for layer in layers:
    layer_group = f.create_group(layer.name)
    values = K.batch_get_value(layer.weights)
    encoded_names = [w.name.encode('utf8') for w in layer.weights]
    save_attributes_to_hdf5_group(layer_group, 'weight_names', encoded_names)
    for weight_name, value in zip(encoded_names, values):
      dataset = layer_group.create_dataset(
          weight_name, value.shape, dtype=value.dtype)
      if value.shape:
        dataset[:] = value
      else:
        # A scalar dataset cannot be sliced; assign through `[()]`.
        dataset[()] = value
699
700
def load_weights_from_hdf5_group(f, layers):
  """Implements topological (order-based) weight loading.

  Model layers without weights and saved layer groups without weight names
  are skipped. The remaining ones are paired up in order.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
  def decode_attr(value):
    # These attributes are written as utf8 `bytes`; depending on the h5py
    # version they may be read back as `bytes` or already as `str`.
    if hasattr(value, 'decode'):
      return value.decode('utf8')
    return value

  if 'keras_version' in f.attrs:
    original_keras_version = decode_attr(f.attrs['keras_version'])
  else:
    original_keras_version = '1'
  if 'backend' in f.attrs:
    original_backend = decode_attr(f.attrs['backend'])
  else:
    original_backend = None

  filtered_layers = [layer for layer in layers if layer.weights]

  # Read each saved layer's weight names once; keep only non-empty layers.
  saved_layers = []
  for name in load_attributes_from_hdf5_group(f, 'layer_names'):
    weight_names = load_attributes_from_hdf5_group(f[name], 'weight_names')
    if weight_names:
      saved_layers.append((name, weight_names))

  if len(saved_layers) != len(filtered_layers):
    raise ValueError('You are trying to load a weight file '
                     'containing ' + str(len(saved_layers)) +
                     ' layers into a model with ' + str(len(filtered_layers)) +
                     ' layers.')

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, (name, weight_names) in enumerate(saved_layers):
    g = f[name]
    weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names]
    layer = filtered_layers[k]
    symbolic_weights = layer.weights
    weight_values = preprocess_weights_for_loading(
        layer, weight_values, original_keras_version, original_backend)
    if len(weight_values) != len(symbolic_weights):
      raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                       '" in the current model) was found to '
                       'correspond to layer ' + name + ' in the save file. '
                       'However the new layer ' + layer.name + ' expects ' +
                       str(len(symbolic_weights)) +
                       ' weights, but the saved weights have ' +
                       str(len(weight_values)) + ' elements.')
    weight_value_tuples.extend(zip(symbolic_weights, weight_values))
  K.batch_set_value(weight_value_tuples)
762
763
def load_weights_from_hdf5_group_by_name(f, layers):
  """Implements name-based weight loading.

  (instead of topological weight loading).

  Layers that have no matching name are skipped. If several layers share a
  saved layer's name, each of them is loaded from that layer's saved weights.

  Arguments:
      f: A pointer to a HDF5 group.
      layers: a list of target layers.

  Raises:
      ValueError: in case of mismatch between provided layers
          and weights file.
  """
  def decode_attr(value):
    # These attributes are written as utf8 `bytes`; depending on the h5py
    # version they may be read back as `bytes` or already as `str`.
    if hasattr(value, 'decode'):
      return value.decode('utf8')
    return value

  if 'keras_version' in f.attrs:
    original_keras_version = decode_attr(f.attrs['keras_version'])
  else:
    original_keras_version = '1'
  if 'backend' in f.attrs:
    original_backend = decode_attr(f.attrs['backend'])
  else:
    original_backend = None

  # New file format.
  layer_names = load_attributes_from_hdf5_group(f, 'layer_names')

  # Reverse index of layer name to list of layers with name.
  index = {}
  for layer in layers:
    if layer.name:
      index.setdefault(layer.name, []).append(layer)

  # We batch weight value assignments in a single backend call
  # which provides a speedup in TensorFlow.
  weight_value_tuples = []
  for k, name in enumerate(layer_names):
    g = f[name]
    weight_names = load_attributes_from_hdf5_group(g, 'weight_names')
    saved_weight_values = [
        np.asarray(g[weight_name]) for weight_name in weight_names
    ]

    for layer in index.get(name, []):
      symbolic_weights = layer.weights
      # Preprocess a fresh copy for every matching layer. The preprocessing
      # may replace list entries in place, and its result must not be fed
      # into the next layer's conversion.
      weight_values = preprocess_weights_for_loading(
          layer, list(saved_weight_values), original_keras_version,
          original_backend)
      if len(weight_values) != len(symbolic_weights):
        raise ValueError('Layer #' + str(k) + ' (named "' + layer.name +
                         '") expects ' + str(len(symbolic_weights)) +
                         ' weight(s), but the saved weights' + ' have ' +
                         str(len(weight_values)) + ' element(s).')
      # Set values.
      for symbolic_weight, weight_value in zip(symbolic_weights,
                                               weight_values):
        if K.int_shape(symbolic_weight) != weight_value.shape:
          raise ValueError('Layer #' + str(k) +' (named "' + layer.name +
                           '"), weight ' + str(symbolic_weight) +
                           ' has shape {}'.format(K.int_shape(
                               symbolic_weight)) +
                           ', but the saved weight has shape ' +
                           str(weight_value.shape) + '.')
        weight_value_tuples.append((symbolic_weight, weight_value))
  K.batch_set_value(weight_value_tuples)
827
828
def save_attributes_to_hdf5_group(group, name, data):
  """Saves attributes (data) of the specified name into the HDF5 group.

  This method deals with an inherent problem of HDF5 file which is not
  able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes.
  When the whole array would exceed that limit, it is split into the
  fewest chunks that fit. The chunks are stored as `name0`, `name1`, ...

  Arguments:
      group: A pointer to a HDF5 group.
      name: A name of the attributes to save.
      data: Attributes data to store (callers in this module pass lists of
          utf8-encoded `bytes`).

  Raises:
    RuntimeError: If any single attribute is too large to be saved.
  """
  # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT`
  # because in that case even chunking the array would not make the saving
  # possible.
  bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT]

  # Expecting this to never be true.
  if bad_attributes:
    # Items are usually `bytes`, which `str.join` cannot concatenate under
    # Python 3; convert them to text for the error message.
    bad_names = [
        x.decode('utf8', 'replace') if isinstance(x, bytes) else str(x)
        for x in bad_attributes
    ]
    raise RuntimeError('The following attributes cannot be saved to HDF5 '
                       'file because they are larger than %d bytes: %s' %
                       (HDF5_OBJECT_HEADER_LIMIT, ', '.join(bad_names)))

  data_npy = np.asarray(data)

  num_chunks = 1
  chunked_data = np.array_split(data_npy, num_chunks)

  # This will never loop forever thanks to the test above.
  while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data):
    num_chunks += 1
    chunked_data = np.array_split(data_npy, num_chunks)

  if num_chunks > 1:
    for chunk_id, chunk_data in enumerate(chunked_data):
      group.attrs['%s%d' % (name, chunk_id)] = chunk_data
  else:
    group.attrs[name] = data
870
871
def load_attributes_from_hdf5_group(group, name):
  """Loads attributes of the specified name from the HDF5 group.

  This method deals with an inherent problem
  of HDF5 file which is not able to store
  data larger than HDF5_OBJECT_HEADER_LIMIT bytes.
  If `name` itself is absent, the data is reassembled from the chunked
  attributes `name0`, `name1`, ... in order. That is the layout written by
  `save_attributes_to_hdf5_group`.

  Arguments:
      group: A pointer to a HDF5 group.
      name: A name of the attributes to load.

  Returns:
      data: Attributes data, as a list of text strings (empty if neither
          `name` nor any chunk of it is present).
  """
  def to_text(item):
    # Items may come back as `bytes`/`np.bytes_` or, depending on the h5py
    # version and how the file was written, already as `str`.
    if hasattr(item, 'decode'):
      return item.decode('utf8')
    return item

  if name in group.attrs:
    return [to_text(n) for n in group.attrs[name]]

  data = []
  chunk_id = 0
  while '%s%d' % (name, chunk_id) in group.attrs:
    data.extend(to_text(n) for n in group.attrs['%s%d' % (name, chunk_id)])
    chunk_id += 1
  return data
896