# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Utilities related to layer/model functionality."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np

from tensorflow.python.keras import backend as K
from tensorflow.python.keras.utils.conv_utils import convert_kernel
from tensorflow.python.util import nest
from tensorflow.python.util.tf_export import keras_export


@keras_export('keras.utils.get_source_inputs')
def get_source_inputs(tensor, layer=None, node_index=None):
32  """Returns the list of input tensors necessary to compute `tensor`.
33
34  Output will always be a list of tensors
35  (potentially with 1 element).
36
37  Arguments:
38      tensor: The tensor to start from.
39      layer: Origin layer of the tensor. Will be
40          determined via tensor._keras_history if not provided.
41      node_index: Origin node index of the tensor.
42
43  Returns:
44      List of input tensors.
45  """
  if not hasattr(tensor, '_keras_history'):
    return tensor

  if layer is None or node_index is None:
    layer, node_index, _ = tensor._keras_history
  if not layer._inbound_nodes:
    return [tensor]
  else:
    node = layer._inbound_nodes[node_index]
    if not node.inbound_layers:
      # Reached an Input layer, stop recursion.
      return nest.flatten(node.input_tensors)
    else:
      source_tensors = []
      for layer, node_index, _, tensor in node.iterate_inbound():
        previous_sources = get_source_inputs(tensor, layer, node_index)
        # Avoid input redundancy.
        for x in previous_sources:
          if x not in source_tensors:
            source_tensors.append(x)
      return source_tensors
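# Example usage of `get_source_inputs` (a minimal, hedged sketch: assumes
# `tf.keras` is available and builds a tiny hypothetical functional model):
#
#   import tensorflow as tf
#   inp = tf.keras.Input(shape=(4,))
#   out = tf.keras.layers.Dense(2)(inp)
#   model = tf.keras.Model(inp, out)
#   sources = get_source_inputs(model.outputs[0])
#   # `sources` is a one-element list holding the model's input tensor.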


def count_params(weights):
  """Count the total number of scalars composing the weights.

  Arguments:
      weights: An iterable containing the weights on which to compute params.

  Returns:
      The total number of scalars composing the weights.
  """
  return int(sum(np.prod(p.get_shape().as_list()) for p in set(weights)))
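# Example usage of `count_params` (a minimal, hedged sketch; `model` is any
# built Keras model):
#
#   n_trainable = count_params(model.trainable_weights)
#   # Duplicated variables are counted once: the weights are de-duplicated
#   # with set() before their shapes are multiplied out.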


def print_summary(model, line_length=None, positions=None, print_fn=None):
  """Prints a summary of a model.

  Arguments:
      model: Keras model instance.
      line_length: Total length of printed lines
          (e.g. set this to adapt the display to different
          terminal window sizes).
      positions: Relative or absolute positions of log elements in each line.
          If not provided, defaults to `[.33, .55, .67, 1.]`.
      print_fn: Print function to use.
          It will be called on each line of the summary.
          You can set it to a custom function
          in order to capture the string summary.
          It defaults to `print` (prints to stdout).
  """
  if print_fn is None:
    print_fn = print

  if model.__class__.__name__ == 'Sequential':
    sequential_like = True
  elif not model._is_graph_network:
    # We treat subclassed models as a simple sequence of layers, for logging
    # purposes.
    sequential_like = True
  else:
    sequential_like = True
    nodes_by_depth = model._nodes_by_depth.values()
    nodes = []
    for v in nodes_by_depth:
      if (len(v) > 1) or (len(v) == 1 and
                          len(nest.flatten(v[0].inbound_layers)) > 1):
        # If the model has multiple nodes, or if the nodes have multiple
        # inbound_layers, the model is no longer sequential.
        sequential_like = False
        break
      nodes += v
    if sequential_like:
      # Search for shared layers.
      for layer in model.layers:
        flag = False
        for node in layer._inbound_nodes:
          if node in nodes:
            if flag:
              sequential_like = False
              break
            else:
              flag = True
        if not sequential_like:
          break

  if sequential_like:
    line_length = line_length or 65
    positions = positions or [.45, .85, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # Header names for the different log elements.
    to_display = ['Layer (type)', 'Output Shape', 'Param #']
  else:
    line_length = line_length or 98
    positions = positions or [.33, .55, .67, 1.]
    if positions[-1] <= 1:
      positions = [int(line_length * p) for p in positions]
    # Header names for the different log elements.
    to_display = ['Layer (type)', 'Output Shape', 'Param #', 'Connected to']
    relevant_nodes = []
    for v in model._nodes_by_depth.values():
      relevant_nodes += v

  def print_row(fields, positions):
    line = ''
    for i in range(len(fields)):
      if i > 0:
        line = line[:-1] + ' '
      line += str(fields[i])
      line = line[:positions[i]]
      line += ' ' * (positions[i] - len(line))
    print_fn(line)

  print_fn('Model: "{}"'.format(model.name))
  print_fn('_' * line_length)
  print_row(to_display, positions)
  print_fn('=' * line_length)

  def print_layer_summary(layer):
    """Prints a summary for a single layer.

    Arguments:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    except RuntimeError:  # output_shape unknown in Eager mode.
      output_shape = '?'
    name = layer.name
    cls_name = layer.__class__.__name__
    fields = [name + ' (' + cls_name + ')', output_shape, layer.count_params()]
    print_row(fields, positions)

  def print_layer_summary_with_connections(layer):
    """Prints a summary for a single layer (including topological connections).

    Arguments:
        layer: target layer.
    """
    try:
      output_shape = layer.output_shape
    except AttributeError:
      output_shape = 'multiple'
    connections = []
    for node in layer._inbound_nodes:
      if relevant_nodes and node not in relevant_nodes:
        # node is not part of the current network
        continue

      for inbound_layer, node_index, tensor_index, _ in node.iterate_inbound():
        connections.append('{}[{}][{}]'.format(inbound_layer.name, node_index,
                                               tensor_index))

    name = layer.name
    cls_name = layer.__class__.__name__
    if not connections:
      first_connection = ''
    else:
      first_connection = connections[0]
    fields = [
        name + ' (' + cls_name + ')', output_shape,
        layer.count_params(), first_connection
    ]
    print_row(fields, positions)
    if len(connections) > 1:
      for i in range(1, len(connections)):
        fields = ['', '', '', connections[i]]
        print_row(fields, positions)
  layers = model.layers
  for i in range(len(layers)):
    if sequential_like:
      print_layer_summary(layers[i])
    else:
      print_layer_summary_with_connections(layers[i])
    if i == len(layers) - 1:
      print_fn('=' * line_length)
    else:
      print_fn('_' * line_length)

  model._check_trainable_weights_consistency()
  if hasattr(model, '_collected_trainable_weights'):
    trainable_count = count_params(model._collected_trainable_weights)
  else:
    trainable_count = count_params(model.trainable_weights)

  non_trainable_count = count_params(model.non_trainable_weights)

  print_fn('Total params: {:,}'.format(trainable_count + non_trainable_count))
  print_fn('Trainable params: {:,}'.format(trainable_count))
  print_fn('Non-trainable params: {:,}'.format(non_trainable_count))
  print_fn('_' * line_length)
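# Example usage of `print_summary` (a minimal, hedged sketch showing how a
# custom `print_fn` can capture the summary as a string instead of writing
# to stdout; `model` is assumed to be a built Keras model):
#
#   lines = []
#   print_summary(model, print_fn=lines.append)
#   summary_text = '\n'.join(lines)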


def gather_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected trainable weights/variables.
  """
  if not trainable:
    return []
  weights = []
  for layer in sub_layers:
    weights += layer.trainable_weights
  trainable_extra_variables = [
      v for v in extra_variables if v.trainable]
  return weights + trainable_extra_variables
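# Example usage of `gather_trainable_weights` (a minimal, hedged sketch;
# `layer_a`, `layer_b` and `extra_var` are hypothetical built layers and a
# hypothetical tf.Variable, respectively):
#
#   weights = gather_trainable_weights(
#       trainable=True, sub_layers=[layer_a, layer_b],
#       extra_variables=[extra_var])
#   # If the collecting object is frozen (trainable=False), the result is [].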


def gather_non_trainable_weights(trainable, sub_layers, extra_variables):
  """Lists the non-trainable weights for an object with sub-layers.

  Args:
    trainable: Whether the object collecting the variables is trainable.
    sub_layers: A flat list of Layer objects owned by this object, to collect
      variables from.
    extra_variables: Any extra variables to include. Their `.trainable` property
      is used to categorize them.

  Returns:
    A list of collected non-trainable weights/variables.
  """
  trainable_extra_variables = []
  non_trainable_extra_variables = []
  for v in extra_variables:
    if v.trainable:
      trainable_extra_variables.append(v)
    else:
      non_trainable_extra_variables.append(v)
  weights = []
  for layer in sub_layers:
    weights += layer.non_trainable_weights
  if not trainable:
    # When the collecting object itself is frozen, its sub-layers' trainable
    # weights are reported as non-trainable as well.
    trainable_weights = []
    for layer in sub_layers:
      trainable_weights += layer.trainable_weights
    return (trainable_weights + trainable_extra_variables
            + weights + non_trainable_extra_variables)
  return weights + non_trainable_extra_variables
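# Example usage of `gather_non_trainable_weights` (a minimal, hedged sketch;
# same hypothetical objects as above):
#
#   frozen = gather_non_trainable_weights(
#       trainable=False, sub_layers=[layer_a, layer_b],
#       extra_variables=[extra_var])
#   # With trainable=False, the sub-layers' trainable weights are included
#   # in the returned list as well.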


@keras_export('keras.utils.convert_all_kernels_in_model')
def convert_all_kernels_in_model(model):
  """Converts all convolution kernels in a model from Theano to TensorFlow.

  Also works from TensorFlow to Theano.

  Arguments:
      model: target model for the conversion.
  """
  # Note: SeparableConvolution not included
  # since only supported by TF.
  conv_classes = {
      'Conv1D',
      'Conv2D',
      'Conv3D',
      'Conv2DTranspose',
  }
  to_assign = []
  for layer in model.layers:
    if layer.__class__.__name__ in conv_classes:
      original_kernel = K.get_value(layer.kernel)
      converted_kernel = convert_kernel(original_kernel)
      to_assign.append((layer.kernel, converted_kernel))
  K.batch_set_value(to_assign)
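# Example usage of `convert_all_kernels_in_model` (a minimal, hedged sketch;
# `model` is assumed to hold weights saved with the other backend's kernel
# layout, e.g. a Theano-trained checkpoint loaded under TensorFlow):
#
#   convert_all_kernels_in_model(model)
#   # All Conv1D/2D/3D and Conv2DTranspose kernels are flipped in place.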


def convert_dense_weights_data_format(dense,
                                      previous_feature_map_shape,
                                      target_data_format='channels_first'):
  """Utility useful when changing a convnet's `data_format`.

  When porting the weights of a convnet from one data format to the other,
  if the convnet includes a `Flatten` layer
  (applied to the last convolutional feature map)
  followed by a `Dense` layer, the weights of that `Dense` layer
  should be updated to reflect the new dimension ordering.

  Arguments:
      dense: The target `Dense` layer.
      previous_feature_map_shape: A shape tuple of 3 integers,
          e.g. `(512, 7, 7)`. The shape of the convolutional
          feature map right before the `Flatten` layer that
          came before the target `Dense` layer.
      target_data_format: One of "channels_last", "channels_first".
          Set it to "channels_last"
          if converting a "channels_first" model to "channels_last",
          or vice versa.
  """
  assert target_data_format in {'channels_last', 'channels_first'}
  kernel, bias = dense.get_weights()
  for i in range(kernel.shape[1]):
    if target_data_format == 'channels_first':
      c, h, w = previous_feature_map_shape
      original_fm_shape = (h, w, c)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (2, 0, 1))  # last -> first
    else:
      h, w, c = previous_feature_map_shape
      original_fm_shape = (c, h, w)
      ki = kernel[:, i].reshape(original_fm_shape)
      ki = np.transpose(ki, (1, 2, 0))  # first -> last
    kernel[:, i] = np.reshape(ki, (np.prod(previous_feature_map_shape),))
  dense.set_weights([kernel, bias])
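# Example usage of `convert_dense_weights_data_format` (a minimal, hedged
# sketch; `dense_layer` is the hypothetical Dense layer that follows a
# Flatten, and (512, 7, 7) mirrors the docstring's example shape):
#
#   convert_dense_weights_data_format(dense_layer, (512, 7, 7),
#                                     target_data_format='channels_first')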


def is_builtin_layer(layer):
  if not getattr(layer, '_keras_api_names', None):
    return False

  # Subclasses of `Layer` that are not exported inherit the export name
  # of the base layer class.
  return (layer._keras_api_names != ('keras.layers.Layer',) and
          layer._keras_api_names_v1 != ('keras.layers.Layer',))
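# Example usage of `is_builtin_layer` (a minimal, hedged sketch; assumes
# `import tensorflow as tf`, and `MyLayer` is a hypothetical user-defined
# tf.keras.layers.Layer subclass):
#
#   is_builtin_layer(tf.keras.layers.Dense(4))  # True: exported Keras layer.
#   is_builtin_layer(MyLayer())  # False: inherits the base Layer export name.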