# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.conv_utils import convert_kernel
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
  """Returns the list of input tensors necessary to compute `tensor`.

  For a tensor carrying `_keras_history`, the output is a list of tensors
  (potentially with 1 element). A tensor without `_keras_history` is
  returned unchanged, i.e. NOT wrapped in a list.

  Arguments:
      tensor: The tensor to start from.
      layer: Origin layer of the tensor. Will be
          determined via tensor._keras_history if not provided.
      node_index: Origin node index of the tensor.

  Returns:
      List of input tensors, or `tensor` itself when it has no
      `_keras_history` attribute.
  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  # NOTE(review): a truthy `node_index` also triggers the lookup and overrides
  # a caller-supplied `layer`; kept as-is, confirm this is intended.
  if layer is None or node_index:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]

  node = layer._inbound_nodes[node_index]
  if not node.inbound_layers:
    # Reached an Input layer, stop recursion.
    return nest.flatten(node.input_tensors)

  source_tensors = []
  for inbound_layer, inbound_index, _, inbound_tensor in (
      node.iterate_inbound()):
    previous_sources = get_source_inputs(inbound_tensor, inbound_layer,
                                         inbound_index)
    # Avoid input redundancy. Compare by identity so that tensor types which
    # overload `==` are never asked to evaluate an equality.
    for candidate in previous_sources:
      if all(candidate is not seen for seen in source_tensors):
        source_tensors.append(candidate)
  return source_tensors


def count_params(weights):
  """Count the total number of scalars composing the weights.

  Duplicate entries in `weights` are counted once (the iterable is passed
  through `set`).

  Arguments:
      weights: An iterable containing the weights on which to compute params

  Returns:
      The total number of scalars composing the weights
  """
  unique_weights = set(weights)
  return int(
      sum(np.prod(w.get_shape().as_list()) for w in unique_weights))


def _is_sequential_like(model):
  """Returns True if `model` can be summarized as a plain stack of layers."""
  if model.__class__.__name__ == 'Sequential':
    return True
  if not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    return True

  nodes = []
  for depth_nodes in model._nodes_by_depth.values():
    # If a depth level has multiple nodes, or its single node has multiple
    # inbound layers, the model is no longer sequential.
    if len(depth_nodes) > 1:
      return False
    if (len(depth_nodes) == 1 and
        len(nest.flatten(depth_nodes[0].inbound_layers)) > 1):
      return False
    nodes += depth_nodes

  # Search for shared layers: a layer with two nodes in this network.
  for layer in model.layers:
    already_seen = False
    for node in layer._inbound_nodes:
      if node in nodes:
        if already_seen:
          return False
        already_seen = True
  return True


def _summary_output_shape(layer):
  """Returns the value shown in the 'Output Shape' column for `layer`."""
  try:
    return layer.output_shape
  except AttributeError:
    return 'multiple'
  except RuntimeError:  # output_shape unknown in Eager mode.
    return '?'


def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Arguments:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes). Defaults to 65 for sequential-like
          models and 98 otherwise.
      positions: Relative or absolute positions of log elements in each line.
          If the last value is `<= 1` the values are treated as fractions
          of `line_length`. If not provided, defaults to `[.45, .85, 1.]`
          for sequential-like models and `[.33, .55, .67, 1.]` otherwise.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  sequential_like = _is_sequential_like(model)

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
    relevant_nodes = []
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    # header names for the different log elements
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for depth_nodes in model._nodes_by_depth.values():
      relevant_nodes += depth_nodes
  if positions[-1] <= 1:
    # Relative positions: convert to absolute column indices.
    positions = [int(line_length * p) for p in positions]

  def print_row(fields, positions):
    """Prints `fields` as one line, each one cut/padded to its column."""
    line = ''
    for i, field in enumerate(fields):
      if i > 0:
        # Force a separating space between adjacent columns.
        line = line[:-1] + ' '
      line += str(field)
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Arguments:
        layer: target layer.
    """
    fields = [
        layer.name + ' (' + layer.__class__.__name__ + ')',
        _summary_output_shape(layer),
        layer.count_params()
    ]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Arguments:
        layer: target layer.
    """
    output_shape = _summary_output_shape(layer)
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in (
          node.iterate_inbound()):
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    first_connection = connections[0] if connections else ''
    fields = [
        layer.name + ' (' + layer.__class__.__name__ + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    # Any further connections go on continuation rows of their own.
    for connection in connections[1:]:
      print_row(['', '', '', connection], positions)

  layers = model.layers
  last_index = len(layers) - 1
  for i, layer in enumerate(layers):
    if sequential_like:
      print_layer_summary(layer)
    else:
      print_layer_summary_with_connections(layer)
    if i == last_index:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  model._check_trainable_weights_consistency()
  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)


def gather_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected trainable weights/variables. Empty when `trainable`
    is falsy.
  """
  if not trainable:
    return []
  weights = []
  for layer in sub_layers:
    weights += layer.trainable_weights
  return weights + [v for v in extra_variables if v.trainable]


def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the non-trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected non-trainable weights/variables. When `trainable` is
    falsy, the trainable weights of the sub-layers and the trainable extra
    variables are included as well, placed before the non-trainable ones.
  """
  # Single pass over `extra_variables`, split by their `.trainable` flag.
  trainable_extra_variables = []
  non_trainable_extra_variables = []
  for v in extra_variables:
    if v.trainable:
      trainable_extra_variables.append(v)
    else:
      non_trainable_extra_variables.append(v)

  weights = []
  for layer in sub_layers:
    weights += layer.non_trainable_weights

  if trainable:
    return weights + non_trainable_extra_variables

  trainable_weights = []
  for layer in sub_layers:
    trainable_weights += layer.trainable_weights
  return (trainable_weights + trainable_extra_variables
          + weights + non_trainable_extra_variables)


@keras_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
  """Converts all convolution kernels in a model from Theano to TensorFlow.

  Also works from TensorFlow to Theano.

  Arguments:
      model: target model for the conversion.
  """
  # Note: SeparableConvolution not included
  # since only supported by TF.
  conv_classes = {
      'Conv1D',
      'Conv2D',
      'Conv3D',
      'Conv2DTranspose',
  }
  to_assign = []
  for layer in model.layers:
    if layer.__class__.__name__ in conv_classes:
      original_kernel = K.get_value(layer.kernel)
      converted_kernel = convert_kernel(original_kernel)
      to_assign.append((layer.kernel, converted_kernel))
  # Assign every converted kernel in one batched call.
  K.batch_set_value(to_assign)


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Only the first weight returned by `dense.get_weights()` (the kernel) is
  permuted; any remaining weights (e.g. a bias) are written back unchanged,
  so layers without a bias are supported too.

  Arguments:
      dense: The target `Dense` layer.
      previous_feature_map_shape: A shape tuple of 3 integers,
          e.g. `(512, 7, 7)`. The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer. It is read as
          `(c, h, w)` when `target_data_format` is "channels_first"
          and as `(h, w, c)` otherwise.
      target_data_format: One of "channels_last", "channels_first".
          Set it "channels_last"
          if converting a "channels_first" model to "channels_last",
          or reciprocally.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  weights = dense.get_weights()
  kernel = weights[0]

  # The reshape target and axis permutation do not depend on the output unit,
  # so compute them once up front.
  if target_data_format == 'channels_first':
    c, h, w = previous_feature_map_shape
    original_fm_shape = (h, w, c)
    axes = (2, 0, 1)  # last -> first
  else:
    h, w, c = previous_feature_map_shape
    original_fm_shape = (c, h, w)
    axes = (1, 2, 0)  # first -> last
  flat_size = np.prod(previous_feature_map_shape)

  for i in range(kernel.shape[1]):
    ki = kernel[:, i].reshape(original_fm_shape)
    ki = np.transpose(ki, axes)
    kernel[:, i] = np.reshape(ki, (flat_size,))
  dense.set_weights([kernel] + list(weights[1:]))


def is_builtin_layer(layer):
  """Returns whether `layer` carries its own exported Keras API name.

  Arguments:
      layer: The layer (or layer class) to inspect.

  Returns:
      False if `layer` has no (or empty) `_keras_api_names`, or if its
      export names are those of the base `keras.layers.Layer`; True
      otherwise.
  """
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          layer._keras_api_names_v1 != ('keras.layers.Layer',))