1# Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# http://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14# ============================================================================== 15# pylint: disable=protected-access 16"""Functions for saving and loading a Keras Model from HDF5 format. 17""" 18from __future__ import absolute_import 19from __future__ import division 20from __future__ import print_function 21 22import json 23import os 24 25import numpy as np 26from six.moves import zip # pylint: disable=redefined-builtin 27 28from tensorflow.python.keras import backend as K 29from tensorflow.python.keras import optimizer_v1 30from tensorflow.python.keras.saving import model_config as model_config_lib 31from tensorflow.python.keras.saving import saving_utils 32from tensorflow.python.keras.saving.saved_model import json_utils 33from tensorflow.python.keras.utils.generic_utils import LazyLoader 34from tensorflow.python.keras.utils.io_utils import ask_to_proceed_with_overwrite 35from tensorflow.python.ops import variables as variables_module 36from tensorflow.python.platform import gfile 37from tensorflow.python.platform import tf_logging as logging 38 39 40# pylint: disable=g-import-not-at-top 41try: 42 import h5py 43 HDF5_OBJECT_HEADER_LIMIT = 64512 44except ImportError: 45 h5py = None 46# pylint: enable=g-import-not-at-top 47 48# TODO(b/134426265): Switch back to single-quotes to match the rest of the file 49# once the issue with copybara is fixed. 50# pylint:disable=g-inconsistent-quotes 51sequential_lib = LazyLoader( 52 "sequential_lib", globals(), 53 "tensorflow.python.keras.engine.sequential") 54# pylint:enable=g-inconsistent-quotes 55 56 57def save_model_to_hdf5(model, filepath, overwrite=True, include_optimizer=True): 58 """Saves a model to a HDF5 file. 59 60 The saved model contains: 61 - the model's configuration (topology) 62 - the model's weights 63 - the model's optimizer's state (if any) 64 65 Thus the saved model can be reinstantiated in 66 the exact same state, without any of the code 67 used for model definition or training. 68 69 Args: 70 model: Keras model instance to be saved. 71 filepath: One of the following: 72 - String, path where to save the model 73 - `h5py.File` object where to save the model 74 overwrite: Whether we should overwrite any existing 75 model at the target location, or instead 76 ask the user with a manual prompt. 77 include_optimizer: If True, save optimizer's state together. 78 79 Raises: 80 ImportError: if h5py is not available. 81 """ 82 83 if h5py is None: 84 raise ImportError('`save_model` requires h5py.') 85 86 # TODO(psv) Add warning when we save models that contain non-serializable 87 # entities like metrics added using `add_metric` and losses added using 88 # `add_loss.` 89 if len(model.weights) != len(model._undeduplicated_weights): 90 logging.warning('Found duplicated `Variable`s in Model\'s `weights`. ' 91 'This is usually caused by `Variable`s being shared by ' 92 'Layers in the Model. These `Variable`s will be treated ' 93 'as separate `Variable`s when the Model is restored. To ' 94 'avoid this, please save with `save_format="tf"`.') 95 96 if not isinstance(filepath, h5py.File): 97 # If file exists and should not be overwritten. 98 if not overwrite and os.path.isfile(filepath): 99 proceed = ask_to_proceed_with_overwrite(filepath) 100 if not proceed: 101 return 102 103 # Try creating dir if not exist 104 dirpath = os.path.dirname(filepath) 105 if not os.path.exists(dirpath): 106 gfile.MakeDirs(dirpath) 107 108 f = h5py.File(filepath, mode='w') 109 opened_new_file = True 110 else: 111 f = filepath 112 opened_new_file = False 113 114 try: 115 model_metadata = saving_utils.model_metadata(model, include_optimizer) 116 for k, v in model_metadata.items(): 117 if isinstance(v, (dict, list, tuple)): 118 f.attrs[k] = json.dumps( 119 v, default=json_utils.get_json_type).encode('utf8') 120 else: 121 f.attrs[k] = v 122 123 model_weights_group = f.create_group('model_weights') 124 model_layers = model.layers 125 save_weights_to_hdf5_group(model_weights_group, model_layers) 126 127 # TODO(b/128683857): Add integration tests between tf.keras and external 128 # Keras, to avoid breaking TF.js users. 129 if (include_optimizer and model.optimizer and 130 not isinstance(model.optimizer, optimizer_v1.TFOptimizer)): 131 save_optimizer_weights_to_hdf5_group(f, model.optimizer) 132 133 f.flush() 134 finally: 135 if opened_new_file: 136 f.close() 137 138 139def load_model_from_hdf5(filepath, custom_objects=None, compile=True): # pylint: disable=redefined-builtin 140 """Loads a model saved via `save_model_to_hdf5`. 141 142 Args: 143 filepath: One of the following: 144 - String, path to the saved model 145 - `h5py.File` object from which to load the model 146 custom_objects: Optional dictionary mapping names 147 (strings) to custom classes or functions to be 148 considered during deserialization. 149 compile: Boolean, whether to compile the model 150 after loading. 151 152 Returns: 153 A Keras model instance. If an optimizer was found 154 as part of the saved model, the model is already 155 compiled. Otherwise, the model is uncompiled and 156 a warning will be displayed. When `compile` is set 157 to False, the compilation is omitted without any 158 warning. 159 160 Raises: 161 ImportError: if h5py is not available. 162 ValueError: In case of an invalid savefile. 163 """ 164 if h5py is None: 165 raise ImportError('`load_model` requires h5py.') 166 167 if not custom_objects: 168 custom_objects = {} 169 170 opened_new_file = not isinstance(filepath, h5py.File) 171 if opened_new_file: 172 f = h5py.File(filepath, mode='r') 173 else: 174 f = filepath 175 176 model = None 177 try: 178 # instantiate model 179 model_config = f.attrs.get('model_config') 180 if model_config is None: 181 raise ValueError('No model found in config file.') 182 if hasattr(model_config, 'decode'): 183 model_config = model_config.decode('utf-8') 184 model_config = json_utils.decode(model_config) 185 model = model_config_lib.model_from_config(model_config, 186 custom_objects=custom_objects) 187 188 # set weights 189 load_weights_from_hdf5_group(f['model_weights'], model.layers) 190 191 if compile: 192 # instantiate optimizer 193 training_config = f.attrs.get('training_config') 194 if hasattr(training_config, 'decode'): 195 training_config = training_config.decode('utf-8') 196 if training_config is None: 197 logging.warning('No training configuration found in the save file, so ' 198 'the model was *not* compiled. Compile it manually.') 199 return model 200 training_config = json_utils.decode(training_config) 201 202 # Compile model. 203 model.compile(**saving_utils.compile_args_from_training_config( 204 training_config, custom_objects)) 205 saving_utils.try_build_compiled_arguments(model) 206 207 # Set optimizer weights. 208 if 'optimizer_weights' in f: 209 try: 210 model.optimizer._create_all_weights(model.trainable_variables) 211 except (NotImplementedError, AttributeError): 212 logging.warning( 213 'Error when creating the weights of optimizer {}, making it ' 214 'impossible to restore the saved optimizer state. As a result, ' 215 'your model is starting with a freshly initialized optimizer.') 216 217 optimizer_weight_values = load_optimizer_weights_from_hdf5_group(f) 218 try: 219 model.optimizer.set_weights(optimizer_weight_values) 220 except ValueError: 221 logging.warning('Error in loading the saved optimizer ' 222 'state. As a result, your model is ' 223 'starting with a freshly initialized ' 224 'optimizer.') 225 finally: 226 if opened_new_file: 227 f.close() 228 return model 229 230 231def preprocess_weights_for_loading(layer, 232 weights, 233 original_keras_version=None, 234 original_backend=None): 235 """Preprocess layer weights between different Keras formats. 236 237 Converts layers weights from Keras 1 format to Keras 2 and also weights of 238 CuDNN layers in Keras 2. 239 240 Args: 241 layer: Layer instance. 242 weights: List of weights values (Numpy arrays). 243 original_keras_version: Keras version for the weights, as a string. 244 original_backend: Keras backend the weights were trained with, 245 as a string. 246 247 Returns: 248 A list of weights values (Numpy arrays). 249 """ 250 def convert_nested_bidirectional(weights): 251 """Converts layers nested in `Bidirectional` wrapper. 252 253 This function uses `preprocess_weights_for_loading()` for converting 254 layers. 255 256 Args: 257 weights: List of weights values (Numpy arrays). 258 259 Returns: 260 A list of weights values (Numpy arrays). 261 """ 262 num_weights_per_layer = len(weights) // 2 263 forward_weights = preprocess_weights_for_loading( 264 layer.forward_layer, weights[:num_weights_per_layer], 265 original_keras_version, original_backend) 266 backward_weights = preprocess_weights_for_loading( 267 layer.backward_layer, weights[num_weights_per_layer:], 268 original_keras_version, original_backend) 269 return forward_weights + backward_weights 270 271 def convert_nested_time_distributed(weights): 272 """Converts layers nested in `TimeDistributed` wrapper. 273 274 This function uses `preprocess_weights_for_loading()` for converting nested 275 layers. 276 277 Args: 278 weights: List of weights values (Numpy arrays). 279 280 Returns: 281 A list of weights values (Numpy arrays). 282 """ 283 return preprocess_weights_for_loading( 284 layer.layer, weights, original_keras_version, original_backend) 285 286 def convert_nested_model(weights): 287 """Converts layers nested in `Model` or `Sequential`. 288 289 This function uses `preprocess_weights_for_loading()` for converting nested 290 layers. 291 292 Args: 293 weights: List of weights values (Numpy arrays). 294 295 Returns: 296 A list of weights values (Numpy arrays). 297 """ 298 trainable_weights = weights[:len(layer.trainable_weights)] 299 non_trainable_weights = weights[len(layer.trainable_weights):] 300 301 new_trainable_weights = [] 302 new_non_trainable_weights = [] 303 304 for sublayer in layer.layers: 305 num_trainable_weights = len(sublayer.trainable_weights) 306 num_non_trainable_weights = len(sublayer.non_trainable_weights) 307 if sublayer.weights: 308 preprocessed = preprocess_weights_for_loading( 309 layer=sublayer, 310 weights=(trainable_weights[:num_trainable_weights] + 311 non_trainable_weights[:num_non_trainable_weights]), 312 original_keras_version=original_keras_version, 313 original_backend=original_backend) 314 new_trainable_weights.extend(preprocessed[:num_trainable_weights]) 315 new_non_trainable_weights.extend(preprocessed[num_trainable_weights:]) 316 317 trainable_weights = trainable_weights[num_trainable_weights:] 318 non_trainable_weights = non_trainable_weights[ 319 num_non_trainable_weights:] 320 321 return new_trainable_weights + new_non_trainable_weights 322 323 # Convert layers nested in Bidirectional/Model/Sequential. 324 # Both transformation should be ran for both Keras 1->2 conversion 325 # and for conversion of CuDNN layers. 326 if layer.__class__.__name__ == 'Bidirectional': 327 weights = convert_nested_bidirectional(weights) 328 if layer.__class__.__name__ == 'TimeDistributed': 329 weights = convert_nested_time_distributed(weights) 330 elif layer.__class__.__name__ in ['Model', 'Sequential']: 331 weights = convert_nested_model(weights) 332 333 if original_keras_version == '1': 334 if layer.__class__.__name__ == 'TimeDistributed': 335 weights = preprocess_weights_for_loading( 336 layer.layer, weights, original_keras_version, original_backend) 337 338 if layer.__class__.__name__ == 'Conv1D': 339 shape = weights[0].shape 340 # Handle Keras 1.1 format 341 if shape[:2] != (layer.kernel_size[0], 1) or shape[3] != layer.filters: 342 # Legacy shape: 343 # (filters, input_dim, filter_length, 1) 344 assert shape[0] == layer.filters and shape[2:] == (layer.kernel_size[0], 345 1) 346 weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) 347 weights[0] = weights[0][:, 0, :, :] 348 349 if layer.__class__.__name__ == 'Conv2D': 350 if layer.data_format == 'channels_first': 351 # old: (filters, stack_size, kernel_rows, kernel_cols) 352 # new: (kernel_rows, kernel_cols, stack_size, filters) 353 weights[0] = np.transpose(weights[0], (2, 3, 1, 0)) 354 355 if layer.__class__.__name__ == 'Conv2DTranspose': 356 if layer.data_format == 'channels_last': 357 # old: (kernel_rows, kernel_cols, stack_size, filters) 358 # new: (kernel_rows, kernel_cols, filters, stack_size) 359 weights[0] = np.transpose(weights[0], (0, 1, 3, 2)) 360 if layer.data_format == 'channels_first': 361 # old: (filters, stack_size, kernel_rows, kernel_cols) 362 # new: (kernel_rows, kernel_cols, filters, stack_size) 363 weights[0] = np.transpose(weights[0], (2, 3, 0, 1)) 364 365 if layer.__class__.__name__ == 'Conv3D': 366 if layer.data_format == 'channels_first': 367 # old: (filters, stack_size, ...) 368 # new: (..., stack_size, filters) 369 weights[0] = np.transpose(weights[0], (2, 3, 4, 1, 0)) 370 371 if layer.__class__.__name__ == 'GRU': 372 if len(weights) == 9: 373 kernel = np.concatenate([weights[0], weights[3], weights[6]], axis=-1) 374 recurrent_kernel = np.concatenate( 375 [weights[1], weights[4], weights[7]], axis=-1) 376 bias = np.concatenate([weights[2], weights[5], weights[8]], axis=-1) 377 weights = [kernel, recurrent_kernel, bias] 378 379 if layer.__class__.__name__ == 'LSTM': 380 if len(weights) == 12: 381 # old: i, c, f, o 382 # new: i, f, c, o 383 kernel = np.concatenate( 384 [weights[0], weights[6], weights[3], weights[9]], axis=-1) 385 recurrent_kernel = np.concatenate( 386 [weights[1], weights[7], weights[4], weights[10]], axis=-1) 387 bias = np.concatenate( 388 [weights[2], weights[8], weights[5], weights[11]], axis=-1) 389 weights = [kernel, recurrent_kernel, bias] 390 391 if layer.__class__.__name__ == 'ConvLSTM2D': 392 if len(weights) == 12: 393 kernel = np.concatenate( 394 [weights[0], weights[6], weights[3], weights[9]], axis=-1) 395 recurrent_kernel = np.concatenate( 396 [weights[1], weights[7], weights[4], weights[10]], axis=-1) 397 bias = np.concatenate( 398 [weights[2], weights[8], weights[5], weights[11]], axis=-1) 399 if layer.data_format == 'channels_first': 400 # old: (filters, stack_size, kernel_rows, kernel_cols) 401 # new: (kernel_rows, kernel_cols, stack_size, filters) 402 kernel = np.transpose(kernel, (2, 3, 1, 0)) 403 recurrent_kernel = np.transpose(recurrent_kernel, (2, 3, 1, 0)) 404 weights = [kernel, recurrent_kernel, bias] 405 406 conv_layers = ['Conv1D', 'Conv2D', 'Conv3D', 'Conv2DTranspose', 'ConvLSTM2D'] 407 if layer.__class__.__name__ in conv_layers: 408 if K.int_shape(layer.weights[0]) != weights[0].shape: 409 weights[0] = np.transpose(weights[0], (3, 2, 0, 1)) 410 if layer.__class__.__name__ == 'ConvLSTM2D': 411 weights[1] = np.transpose(weights[1], (3, 2, 0, 1)) 412 413 # convert CuDNN layers 414 return _convert_rnn_weights(layer, weights) 415 416 417def _convert_rnn_weights(layer, weights): 418 """Converts weights for RNN layers between native and CuDNN format. 419 420 Input kernels for each gate are transposed and converted between Fortran 421 and C layout, recurrent kernels are transposed. For LSTM biases are summed/ 422 split in half, for GRU biases are reshaped. 423 424 Weights can be converted in both directions between `LSTM` and`CuDNNSLTM` 425 and between `CuDNNGRU` and `GRU(reset_after=True)`. Default `GRU` is not 426 compatible with `CuDNNGRU`. 427 428 For missing biases in `LSTM`/`GRU` (`use_bias=False`) no conversion is made. 429 430 Args: 431 layer: Target layer instance. 432 weights: List of source weights values (input kernels, recurrent 433 kernels, [biases]) (Numpy arrays). 434 435 Returns: 436 A list of converted weights values (Numpy arrays). 437 438 Raises: 439 ValueError: for incompatible GRU layer/weights or incompatible biases 440 """ 441 442 def transform_kernels(kernels, func, n_gates): 443 """Transforms kernel for each gate separately using given function. 444 445 Args: 446 kernels: Stacked array of kernels for individual gates. 447 func: Function applied to kernel of each gate. 448 n_gates: Number of gates (4 for LSTM, 3 for GRU). 449 450 Returns: 451 Stacked array of transformed kernels. 452 """ 453 return np.hstack([func(k) for k in np.hsplit(kernels, n_gates)]) 454 455 def transpose_input(from_cudnn): 456 """Makes a function that transforms input kernels from/to CuDNN format. 457 458 It keeps the shape, but changes between the layout (Fortran/C). Eg.: 459 460 ``` 461 Keras CuDNN 462 [[0, 1, 2], <---> [[0, 2, 4], 463 [3, 4, 5]] [1, 3, 5]] 464 ``` 465 466 It can be passed to `transform_kernels()`. 467 468 Args: 469 from_cudnn: `True` if source weights are in CuDNN format, `False` 470 if they're in plain Keras format. 471 472 Returns: 473 Function that converts input kernel to the other format. 474 """ 475 order = 'F' if from_cudnn else 'C' 476 477 def transform(kernel): 478 return kernel.T.reshape(kernel.shape, order=order) 479 480 return transform 481 482 target_class = layer.__class__.__name__ 483 484 # convert the weights between CuDNNLSTM and LSTM 485 if target_class in ['LSTM', 'CuDNNLSTM'] and len(weights) == 3: 486 # determine if we're loading a CuDNNLSTM layer 487 # from the number of bias weights: 488 # CuDNNLSTM has (units * 8) weights; while LSTM has (units * 4) 489 # if there's no bias weight in the file, skip this conversion 490 units = weights[1].shape[0] 491 bias_shape = weights[2].shape 492 n_gates = 4 493 494 if bias_shape == (2 * units * n_gates,): 495 source = 'CuDNNLSTM' 496 elif bias_shape == (units * n_gates,): 497 source = 'LSTM' 498 else: 499 raise ValueError('Invalid bias shape: ' + str(bias_shape)) 500 501 def convert_lstm_weights(weights, from_cudnn=True): 502 """Converts the weights between CuDNNLSTM and LSTM. 503 504 Args: 505 weights: Original weights. 506 from_cudnn: Indicates whether original weights are from CuDNN layer. 507 508 Returns: 509 Updated weights compatible with LSTM. 510 """ 511 512 # Transpose (and reshape) input and recurrent kernels 513 kernels = transform_kernels(weights[0], transpose_input(from_cudnn), 514 n_gates) 515 recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) 516 if from_cudnn: 517 # merge input and recurrent biases into a single set 518 biases = np.sum(np.split(weights[2], 2, axis=0), axis=0) 519 else: 520 # Split single set of biases evenly to two sets. The way of 521 # splitting doesn't matter as long as the two sets sum is kept. 522 biases = np.tile(0.5 * weights[2], 2) 523 return [kernels, recurrent_kernels, biases] 524 525 if source != target_class: 526 weights = convert_lstm_weights(weights, from_cudnn=source == 'CuDNNLSTM') 527 528 # convert the weights between CuDNNGRU and GRU(reset_after=True) 529 if target_class in ['GRU', 'CuDNNGRU'] and len(weights) == 3: 530 # We can determine the source of the weights from the shape of the bias. 531 # If there is no bias we skip the conversion since 532 # CuDNNGRU always has biases. 533 534 units = weights[1].shape[0] 535 bias_shape = weights[2].shape 536 n_gates = 3 537 538 def convert_gru_weights(weights, from_cudnn=True): 539 """Converts the weights between CuDNNGRU and GRU. 540 541 Args: 542 weights: Original weights. 543 from_cudnn: Indicates whether original weights are from CuDNN layer. 544 545 Returns: 546 Updated weights compatible with GRU. 547 """ 548 549 kernels = transform_kernels(weights[0], transpose_input(from_cudnn), 550 n_gates) 551 recurrent_kernels = transform_kernels(weights[1], lambda k: k.T, n_gates) 552 biases = np.array(weights[2]).reshape((2, -1) if from_cudnn else -1) 553 return [kernels, recurrent_kernels, biases] 554 555 if bias_shape == (2 * units * n_gates,): 556 source = 'CuDNNGRU' 557 elif bias_shape == (2, units * n_gates): 558 source = 'GRU(reset_after=True)' 559 elif bias_shape == (units * n_gates,): 560 source = 'GRU(reset_after=False)' 561 else: 562 raise ValueError('Invalid bias shape: ' + str(bias_shape)) 563 564 if target_class == 'CuDNNGRU': 565 target = 'CuDNNGRU' 566 elif layer.reset_after: 567 target = 'GRU(reset_after=True)' 568 else: 569 target = 'GRU(reset_after=False)' 570 571 # only convert between different types 572 if source != target: 573 types = (source, target) 574 if 'GRU(reset_after=False)' in types: 575 raise ValueError('%s is not compatible with %s' % types) 576 if source == 'CuDNNGRU': 577 weights = convert_gru_weights(weights, from_cudnn=True) 578 elif source == 'GRU(reset_after=True)': 579 weights = convert_gru_weights(weights, from_cudnn=False) 580 581 return weights 582 583 584def save_optimizer_weights_to_hdf5_group(hdf5_group, optimizer): 585 """Saves optimizer weights of a optimizer to a HDF5 group. 586 587 Args: 588 hdf5_group: HDF5 group. 589 optimizer: optimizer instance. 590 """ 591 592 symbolic_weights = getattr(optimizer, 'weights') 593 if symbolic_weights: 594 weights_group = hdf5_group.create_group('optimizer_weights') 595 weight_names = [str(w.name).encode('utf8') for w in symbolic_weights] 596 save_attributes_to_hdf5_group(weights_group, 'weight_names', weight_names) 597 weight_values = K.batch_get_value(symbolic_weights) 598 for name, val in zip(weight_names, weight_values): 599 param_dset = weights_group.create_dataset( 600 name, val.shape, dtype=val.dtype) 601 if not val.shape: 602 # scalar 603 param_dset[()] = val 604 else: 605 param_dset[:] = val 606 607 608def load_optimizer_weights_from_hdf5_group(hdf5_group): 609 """Load optimizer weights from a HDF5 group. 610 611 Args: 612 hdf5_group: A pointer to a HDF5 group. 613 614 Returns: 615 data: List of optimizer weight names. 616 """ 617 weights_group = hdf5_group['optimizer_weights'] 618 optimizer_weight_names = load_attributes_from_hdf5_group( 619 weights_group, 'weight_names') 620 return [weights_group[weight_name] for weight_name in optimizer_weight_names] 621 622 623def save_weights_to_hdf5_group(f, layers): 624 """Saves the weights of a list of layers to a HDF5 group. 625 626 Args: 627 f: HDF5 group. 628 layers: List of layer instances. 629 """ 630 from tensorflow.python.keras import __version__ as keras_version # pylint: disable=g-import-not-at-top 631 632 save_attributes_to_hdf5_group( 633 f, 'layer_names', [layer.name.encode('utf8') for layer in layers]) 634 f.attrs['backend'] = K.backend().encode('utf8') 635 f.attrs['keras_version'] = str(keras_version).encode('utf8') 636 637 # Sort model layers by layer name to ensure that group names are strictly 638 # growing to avoid prefix issues. 639 for layer in sorted(layers, key=lambda x: x.name): 640 g = f.create_group(layer.name) 641 weights = _legacy_weights(layer) 642 weight_values = K.batch_get_value(weights) 643 weight_names = [w.name.encode('utf8') for w in weights] 644 save_attributes_to_hdf5_group(g, 'weight_names', weight_names) 645 for name, val in zip(weight_names, weight_values): 646 param_dset = g.create_dataset(name, val.shape, dtype=val.dtype) 647 if not val.shape: 648 # scalar 649 param_dset[()] = val 650 else: 651 param_dset[:] = val 652 653 654def load_weights_from_hdf5_group(f, layers): 655 """Implements topological (order-based) weight loading. 656 657 Args: 658 f: A pointer to a HDF5 group. 659 layers: a list of target layers. 660 661 Raises: 662 ValueError: in case of mismatch between provided layers 663 and weights file. 664 """ 665 if 'keras_version' in f.attrs: 666 original_keras_version = f.attrs['keras_version'] 667 if hasattr(original_keras_version, 'decode'): 668 original_keras_version = original_keras_version.decode('utf8') 669 else: 670 original_keras_version = '1' 671 if 'backend' in f.attrs: 672 original_backend = f.attrs['backend'] 673 if hasattr(original_backend, 'decode'): 674 original_backend = original_backend.decode('utf8') 675 else: 676 original_backend = None 677 678 filtered_layers = [] 679 for layer in layers: 680 weights = _legacy_weights(layer) 681 if weights: 682 filtered_layers.append(layer) 683 684 layer_names = load_attributes_from_hdf5_group(f, 'layer_names') 685 filtered_layer_names = [] 686 for name in layer_names: 687 g = f[name] 688 weight_names = load_attributes_from_hdf5_group(g, 'weight_names') 689 if weight_names: 690 filtered_layer_names.append(name) 691 layer_names = filtered_layer_names 692 if len(layer_names) != len(filtered_layers): 693 raise ValueError('You are trying to load a weight file ' 694 'containing ' + str(len(layer_names)) + 695 ' layers into a model with ' + str(len(filtered_layers)) + 696 ' layers.') 697 698 # We batch weight value assignments in a single backend call 699 # which provides a speedup in TensorFlow. 700 weight_value_tuples = [] 701 for k, name in enumerate(layer_names): 702 g = f[name] 703 weight_names = load_attributes_from_hdf5_group(g, 'weight_names') 704 weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names] 705 layer = filtered_layers[k] 706 symbolic_weights = _legacy_weights(layer) 707 weight_values = preprocess_weights_for_loading( 708 layer, weight_values, original_keras_version, original_backend) 709 if len(weight_values) != len(symbolic_weights): 710 raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + 711 '" in the current model) was found to ' 712 'correspond to layer ' + name + ' in the save file. ' 713 'However the new layer ' + layer.name + ' expects ' + 714 str(len(symbolic_weights)) + 715 ' weights, but the saved weights have ' + 716 str(len(weight_values)) + ' elements.') 717 weight_value_tuples += zip(symbolic_weights, weight_values) 718 K.batch_set_value(weight_value_tuples) 719 720 721def load_weights_from_hdf5_group_by_name( 722 f, layers, skip_mismatch=False): 723 """Implements name-based weight loading. 724 725 (instead of topological weight loading). 726 727 Layers that have no matching name are skipped. 728 729 Args: 730 f: A pointer to a HDF5 group. 731 layers: a list of target layers. 732 skip_mismatch: Boolean, whether to skip loading of layers 733 where there is a mismatch in the number of weights, 734 or a mismatch in the shape of the weights. 735 736 Raises: 737 ValueError: in case of mismatch between provided layers 738 and weights file and skip_match=False. 739 """ 740 if 'keras_version' in f.attrs: 741 original_keras_version = f.attrs['keras_version'] 742 if hasattr(original_keras_version, 'decode'): 743 original_keras_version = original_keras_version.decode('utf8') 744 else: 745 original_keras_version = '1' 746 if 'backend' in f.attrs: 747 original_backend = f.attrs['backend'] 748 if hasattr(original_backend, 'decode'): 749 original_backend = original_backend.decode('utf8') 750 else: 751 original_backend = None 752 753 # New file format. 754 layer_names = load_attributes_from_hdf5_group(f, 'layer_names') 755 756 # Reverse index of layer name to list of layers with name. 757 index = {} 758 for layer in layers: 759 if layer.name: 760 index.setdefault(layer.name, []).append(layer) 761 762 # We batch weight value assignments in a single backend call 763 # which provides a speedup in TensorFlow. 764 weight_value_tuples = [] 765 for k, name in enumerate(layer_names): 766 g = f[name] 767 weight_names = load_attributes_from_hdf5_group(g, 'weight_names') 768 weight_values = [np.asarray(g[weight_name]) for weight_name in weight_names] 769 770 for layer in index.get(name, []): 771 symbolic_weights = _legacy_weights(layer) 772 weight_values = preprocess_weights_for_loading( 773 layer, weight_values, original_keras_version, original_backend) 774 if len(weight_values) != len(symbolic_weights): 775 if skip_mismatch: 776 logging.warning('Skipping loading of weights for ' 777 'layer {}'.format(layer.name) + ' due to mismatch ' 778 'in number of weights ({} vs {}).'.format( 779 len(symbolic_weights), len(weight_values))) 780 continue 781 raise ValueError('Layer #' + str(k) + ' (named "' + layer.name + 782 '") expects ' + str(len(symbolic_weights)) + 783 ' weight(s), but the saved weights' + ' have ' + 784 str(len(weight_values)) + ' element(s).') 785 # Set values. 786 for i in range(len(weight_values)): 787 if K.int_shape(symbolic_weights[i]) != weight_values[i].shape: 788 if skip_mismatch: 789 logging.warning('Skipping loading of weights for ' 790 'layer {}'.format(layer.name) + ' due to ' 791 'mismatch in shape ({} vs {}).'.format( 792 symbolic_weights[i].shape, 793 weight_values[i].shape)) 794 continue 795 raise ValueError('Layer #' + str(k) +' (named "' + layer.name + 796 '"), weight ' + str(symbolic_weights[i]) + 797 ' has shape {}'.format(K.int_shape( 798 symbolic_weights[i])) + 799 ', but the saved weight has shape ' + 800 str(weight_values[i].shape) + '.') 801 802 else: 803 weight_value_tuples.append((symbolic_weights[i], weight_values[i])) 804 K.batch_set_value(weight_value_tuples) 805 806 807def save_attributes_to_hdf5_group(group, name, data): 808 """Saves attributes (data) of the specified name into the HDF5 group. 809 810 This method deals with an inherent problem of HDF5 file which is not 811 able to store data larger than HDF5_OBJECT_HEADER_LIMIT bytes. 812 813 Args: 814 group: A pointer to a HDF5 group. 815 name: A name of the attributes to save. 816 data: Attributes data to store. 817 818 Raises: 819 RuntimeError: If any single attribute is too large to be saved. 820 """ 821 # Check that no item in `data` is larger than `HDF5_OBJECT_HEADER_LIMIT` 822 # because in that case even chunking the array would not make the saving 823 # possible. 824 bad_attributes = [x for x in data if len(x) > HDF5_OBJECT_HEADER_LIMIT] 825 826 # Expecting this to never be true. 827 if bad_attributes: 828 raise RuntimeError('The following attributes cannot be saved to HDF5 ' 829 'file because they are larger than %d bytes: %s' % 830 (HDF5_OBJECT_HEADER_LIMIT, ', '.join(bad_attributes))) 831 832 data_npy = np.asarray(data) 833 834 num_chunks = 1 835 chunked_data = np.array_split(data_npy, num_chunks) 836 837 # This will never loop forever thanks to the test above. 838 while any(x.nbytes > HDF5_OBJECT_HEADER_LIMIT for x in chunked_data): 839 num_chunks += 1 840 chunked_data = np.array_split(data_npy, num_chunks) 841 842 if num_chunks > 1: 843 for chunk_id, chunk_data in enumerate(chunked_data): 844 group.attrs['%s%d' % (name, chunk_id)] = chunk_data 845 else: 846 group.attrs[name] = data 847 848 849def load_attributes_from_hdf5_group(group, name): 850 """Loads attributes of the specified name from the HDF5 group. 851 852 This method deals with an inherent problem 853 of HDF5 file which is not able to store 854 data larger than HDF5_OBJECT_HEADER_LIMIT bytes. 855 856 Args: 857 group: A pointer to a HDF5 group. 858 name: A name of the attributes to load. 859 860 Returns: 861 data: Attributes data. 862 """ 863 if name in group.attrs: 864 data = [ 865 n.decode('utf8') if hasattr(n, 'decode') else n 866 for n in group.attrs[name] 867 ] 868 else: 869 data = [] 870 chunk_id = 0 871 while '%s%d' % (name, chunk_id) in group.attrs: 872 data.extend([ 873 n.decode('utf8') if hasattr(n, 'decode') else n 874 for n in group.attrs['%s%d' % (name, chunk_id)] 875 ]) 876 chunk_id += 1 877 return data 878 879 880def _legacy_weights(layer): 881 """DO NOT USE. 882 883 For legacy reason, the layer.weights was in the order of 884 [self.trainable_weights + self.non_trainable_weights], and this order was 885 used for preserving the weights in h5 format. The new order of layer.weights 886 are the same as layer.get_weights() which is more intuitive for user. To 887 keep supporting the existing saved h5 file, this method should be used to 888 save/load weights. In future version, we will delete this method and 889 introduce a breaking change for h5 and stay with the new order for weights. 890 891 Args: 892 layer: a `tf.keras.Model` or `tf.keras.layers.Layer` instance. 893 894 Returns: 895 A list of variables with the order of trainable_weights, followed by 896 non_trainable_weights. 897 """ 898 weights = layer.trainable_weights + layer.non_trainable_weights 899 if any(not isinstance(w, variables_module.Variable) for w in weights): 900 raise NotImplementedError( 901 'Save or restore weights that is not an instance of `tf.Variable` is ' 902 'not supported in h5, use `save_format=\'tf\'` instead. Got a model ' 903 'or layer {} with weights {}'.format(layer.__class__.__name__, weights)) 904 return weights 905